借助 Queue 实现多线程间的协同
Pipeline
并行地执行多个任务的 Python 程序通常都需要一种协作机制,使得多个线程负责的各部分之间的工作能够相互协同。
其中一种协作机制称为管线(pipeline)。pipeline 的工作方式类似于工厂里的流水线,分为串行排列的多道工序(phase)。每道工序都由特定的函数处理,函数之间可以并行地执行。
比如需要创建这样一个系统,可以从相机接收持续的图片流,再将收到的图片更改尺寸,最后上传到线上的图片库中。
这样的系统就可以分为三道工序,分别用 download
、resize
、upload
三个函数去处理。此外还需要一个在各道工序间传递任务对象的媒介,这个可以通过线程安全的 producer-consumer 队列去实现。
具体的示例代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79import time
from threading import Thread
from collections import deque
from threading import Lock
def upload(item):
pass
def download(item):
pass
def resize(item):
pass
class MyQueue:
def __init__(self) -> None:
self.items = deque()
self.lock = Lock()
def put(self, item):
with self.lock:
self.items.append(item)
def get(self):
with self.lock:
return self.items.popleft()
class Worker(Thread):
def __init__(self, func, in_queue, out_queue):
super().__init__()
self.func = func
self.in_queue = in_queue
self.out_queue = out_queue
self.polled_count = 0
# self.work_done = 0
def run(self):
while True:
self.polled_count += 1
try:
item = self.in_queue.get()
except IndexError:
time.sleep(0.01)
else:
result = self.func(item)
self.out_queue.put(result)
# self.work_done += 1
download_queue = MyQueue()
resize_queue = MyQueue()
upload_queue = MyQueue()
done_queue = MyQueue()
threads = [
Worker(download, download_queue, resize_queue),
Worker(resize, resize_queue, upload_queue),
Worker(upload, upload_queue, done_queue),
]
for thread in threads:
thread.start()
for i in range(100):
download_queue.put(object())
while len(done_queue.items) < 100:
pass
processed = len(done_queue.items)
polled = sum(t.polled_count for t in threads)
print(f'Processed {processed} items after '
f'polling {polled} times')
# Processed 100 items after polling 308 times
上述实现虽然能够处理完成输入的所有任务,但仍存在很多问题。
首先是 polled_count
值远大于任务的数量。即工作线程的 run
方法中定义的从队列中取项目的动作执行了太多次。
各个工作函数的执行速度其实是不一致的,前置位的工作函数(比如 download
)运行缓慢,会导致后一道工序(比如 resize
)上的函数持续不断的向其队列请求新的任务,然而队列为空导致不断地触发 IndexError
错误,最终导致 CPU 时间的浪费。
其次,确认所有任务是否全部完成,需要一个 while
循环不断地检查 done_queue
队列中元素的数量。
再次,工作线程中的 run
方法会一直处于 while True
的循环当中,没有一种明显的方法可以向该工作线程发送任务完成可以退出的消息。
最后,当第一道工序执行很快而第二道工序执行很慢时,处于两道工序之间的队列中的元素数量会持续增长。如果有足够多的任务和足够长的时间,程序最终会耗尽内存并崩溃。
Queue
内置的 queue
模块中的 Queue
类可以解决上述问题。
Queue
类中的 get
方法是阻塞的,即在有新的项目放置到队列中以前,get
会一直处于等待状态,直到获取到某个项目。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25import time
from queue import Queue
from threading import Thread
my_queue = Queue()
def consumer():
print('Consumer waiting')
my_queue.get()
print('Consumer done')
thread = Thread(target=consumer)
thread.start()
time.sleep(1)
print('Producer putting')
my_queue.put(object())
print('Producer done')
thread.join()
# Consumer waiting
# Producer putting
# Producer done
# Consumer done
即便线程先于主程序运行,它也会先处于等待状态,直到一个新的项目被放置到队列中,能够被 get
获取到。
这可以解决前面的程序中 polled_count
值过大的问题。
Queue
类可以指定 buffer size,从而限制了两道工序间 pending 的任务的最大数量。即队列中的元素数量达到最大值后,向队列中放入新元素的 put
方法会阻塞,等待队列中某个元素被消耗从而为新元素腾出空间。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31import time
from threading import Thread
from queue import Queue
my_queue = Queue(1)
def consumer():
time.sleep(1)
my_queue.get()
print('Consumer got 1')
my_queue.get()
print('Consumer got 2')
print('Consumer done')
thread = Thread(target=consumer)
thread.start()
my_queue.put(object())
print('Producer put 1')
my_queue.put(object())
print('Producer put 2')
print('Producer done')
thread.join()
# Producer put 1
# Consumer got 1
# Producer put 2
# Producer done
# Consumer got 2
# Consumer done
Consumer 线程中的 sleep
应该使得主程序有足够的时间将两个对象都放置到队列中。但队列的大小是 1,就导致队列中先放入的元素必须通过 get
方法取出之后,才能继续使用 put
方法放置新的元素进去。
即 Producer 会等待 Consumer 线程把放置到队列中的旧元素消耗掉,才能继续向队列中添加新的元素。
task_done
Queue
类可以使用其 task_done
方法来追踪任务的进度,使得程序可以确保在某个特定的时间点,队列中的所有任务都已经被处理完成。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31from queue import Queue
from threading import Thread
import time
in_queue = Queue()
def consumer():
print('Consumer waiting')
work = in_queue.get()
print('Consumer working')
print('Consumer done')
in_queue.task_done()
thread = Thread(target=consumer)
thread.start()
print('Producer putting')
in_queue.put(object())
print('Producer waiting')
in_queue.join()
print('Producer done')
thread.join()
# Consumer waiting
# Producer putting
# Producer waiting
# Consumer working
# Consumer done
# Producer done
在代码中调用 in_queue.join()
后,只有队列 in_queue
中的所有元素都执行了一遍 task_done
(即有几个元素就需要几条 task_done
),in_queue.join()
之后的代码才会执行。否则就继续等待,直到 Consumer 调用了足够次数的 task_done
。
结合前面提到的特性,可以创建一个新的 Queue
类,它能够告知工作线程什么时候该停止执行。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15class ClosableQueue(Queue):
SENTINEL = object()
def close(self):
self.put(self.SENTINEL)
def __iter__(self):
while True:
item = self.get()
try:
if item is self.SENTINEL:
return # Cause the thread to exit
yield item
finally:
self.task_done()
更新后的完整代码如下:1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74import time
from queue import Queue
from threading import Thread
from collections import deque
from threading import Lock
def upload(item):
pass
def download(item):
pass
def resize(item):
pass
class ClosableQueue(Queue):
SENTINEL = object()
def close(self):
self.put(self.SENTINEL)
def __iter__(self):
while True:
item = self.get()
try:
if item is self.SENTINEL:
return # Cause the thread to exit
yield item
finally:
self.task_done()
class StoppableWorker(Thread):
def __init__(self, func, in_queue, out_queue):
super().__init__()
self.func = func
self.in_queue = in_queue
self.out_queue = out_queue
def run(self):
for item in self.in_queue:
result = self.func(item)
self.out_queue.put(result)
download_queue = ClosableQueue()
resize_queue = ClosableQueue()
upload_queue = ClosableQueue()
done_queue = ClosableQueue()
threads = [
StoppableWorker(download, download_queue, resize_queue),
StoppableWorker(resize, resize_queue, upload_queue),
StoppableWorker(upload, upload_queue, done_queue),
]
for thread in threads:
thread.start()
for _ in range(1000):
download_queue.put(object())
download_queue.close()
download_queue.join()
resize_queue.close()
resize_queue.join()
upload_queue.close()
upload_queue.join()
print(done_queue.qsize(), 'items finished')
# 1000 items finished
逻辑上就是给 Queue
类加了一个 SENTINEL
对象,用来作为队列结束的标志。工作线程通过循环读取输入队列中的任务,这些任务对象经过特定函数处理后放置到输出队列中。若读取到的任务是 SENTINEL
对象,则线程结束运行。
task_done
方法和主程序中的 xxx_queue.join
用于确保某个队列中的所有任务都已经处理完成,转移到了下一个队列中。后面再调用下一个队列的 close
方法在该队列尾部添加一个 SENTINEL
对象,作为队列的结束标志。
上述实现的好处在于,工作线程会在读取到 SENTINEL
对象时自动结束运行;主程序中 upload_queue.join()
执行结束后就能保证三个阶段的所有任务都被处理完了,而不再需要频繁地去检查 done_queue
中的元素数量。
最终实现
当需要对不同的阶段(download
、resize
、upload
)都分别绑定多个线程去处理时,只稍微修改下代码就可以了。1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87import time
from queue import Queue
from threading import Thread
from collections import deque
from threading import Lock
def upload(item):
pass
def download(item):
pass
def resize(item):
pass
class ClosableQueue(Queue):
SENTINEL = object()
def close(self):
self.put(self.SENTINEL)
def __iter__(self):
while True:
item = self.get()
try:
if item is self.SENTINEL:
return # Cause the thread to exit
yield item
finally:
self.task_done()
class StoppableWorker(Thread):
def __init__(self, func, in_queue, out_queue):
super().__init__()
self.func = func
self.in_queue = in_queue
self.out_queue = out_queue
def run(self):
for item in self.in_queue:
result = self.func(item)
self.out_queue.put(result)
def start_threads(count, *args):
threads = [StoppableWorker(*args) for _ in range(count)]
for thread in threads:
thread.start()
return threads
def stop_threads(closable_queue, threads):
for _ in threads:
closable_queue.close()
closable_queue.join()
for thread in threads:
thread.join()
download_queue = ClosableQueue()
resize_queue = ClosableQueue()
upload_queue = ClosableQueue()
done_queue = ClosableQueue()
download_threads = start_threads(
3, download, download_queue, resize_queue)
resize_threads = start_threads(
4, resize, resize_queue, upload_queue)
upload_threads = start_threads(
5, upload, upload_queue, done_queue)
for _ in range(1000):
download_queue.put(object())
stop_threads(download_queue, download_threads)
stop_threads(resize_queue, resize_threads)
stop_threads(upload_queue, upload_threads)
print(done_queue.qsize(), 'items finished')
# 1000 items finished
要点
- Pipeline 可以很好地组织流水线类型的工作,尤其是 IO 相关的 Python 多线程程序
- 需要特别注意构建 pipeline 时的隐藏问题:怎样告诉工作线程终止运行、busy waiting 以及潜在的内存爆炸等
Queue
类具备构建健壮的 pipeline 所需的特性,如阻塞式操作、buffer size 和 joining 等。