-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathiterablequeue.py
More file actions
233 lines (197 loc) · 6.98 KB
/
iterablequeue.py
File metadata and controls
233 lines (197 loc) · 6.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
from asyncio import Queue, QueueFull, QueueEmpty, Event, Lock
from typing import AsyncIterable, TypeVar, Optional
from .utils import Countable
import logging
# Setup logging
# Use a module-scoped logger rather than the root logger so that applications
# can filter or silence this module's records by name. Records still propagate
# to handlers configured on the root logger.
logger = logging.getLogger(__name__)

# Short aliases used throughout the module, one per log level.
error = logger.error
message = logger.warning
verbose = logger.info
debug = logger.debug

# Type of the items stored in an IterableQueue
T = TypeVar("T")
class QueueDone(Exception):
    """Raised when a queue has been marked as filled.

    Producers get it when trying to add to a filled queue, consumers get it
    once the end-of-queue sentinel has been reached.
    """
class IterableQueue(Queue[T], AsyncIterable[T], Countable):
    """asyncio.Queue subclass with automatic termination once the queue has
    been filled and emptied.

    Supports:

    - The Queue() interface; put_nowait() / get_nowait() are best-effort
      implementations
    - AsyncIterable: ``async for item in queue:``. Each iteration step marks
      the previously returned item as task_done()
    - Automatic termination of the consumers when the queue has been emptied
      (QueueDone exception, StopAsyncIteration when iterating)
    - Producers must be registered with add_producer() and they must notify
      once they have finished adding items with finish()
    - Countable interface to count the number of items task_done() through
      the 'count' property
    - Counting can be disabled with count_items=False. This is useful when you
      want to sum the count of multiple IterableQueues

    Internally a single ``None`` sentinel is appended after the last producer
    has finished. Every consumer that fetches the sentinel puts it back so the
    remaining consumers are terminated as well.
    """

    def __init__(self, count_items: bool = True, **kwargs):
        """Create the queue.

        * count_items: whether the 'count' property reports processed items
        * kwargs: passed as-is to the wrapped asyncio.Queue (e.g. maxsize)
        """
        # _Q is required instead of inheriting from Queue()
        # using super() since Queue is Optional[T], not [T]
        self._Q: Queue[Optional[T]] = Queue(**kwargs)
        self._producers: int = 0  # registered producers still active
        self._count_items: bool = count_items
        self._count: int = 0  # items marked task_done()
        self._wip: int = 0  # items fetched but not yet task_done()
        self._modify: Lock = Lock()  # guards producer bookkeeping
        self._put_lock: Lock = Lock()  # serialises writes into _Q
        self._filled: Event = Event()  # set once all producers have finished
        self._empty: Event = Event()  # set while no real items are queued
        self._done: Event = Event()  # set once filled, empty and no WIP
        # True only while the None sentinel actually sits in _Q. The sentinel
        # is absent between _filled being set and the sentinel being queued,
        # and while a consumer is handing it back.
        self._sentinel_queued: bool = False
        self._empty.set()

    @property
    def is_filled(self) -> bool:
        """True once all producers have finished (or finish(all=True))."""
        return self._filled.is_set()

    @property
    def is_done(self) -> bool:
        """True when the queue is filled, empty and has no work in progress."""
        return self.is_filled and self.empty() and not self.has_wip

    @property
    def maxsize(self) -> int:
        """Maximum size of the wrapped queue."""
        return self._Q.maxsize

    @property
    def _maxsize(self) -> int:
        # Read-only stand-in for the attribute asyncio.Queue normally sets.
        return self.maxsize

    def full(self) -> bool:
        """True if the wrapped queue is full."""
        return self._Q.full()

    def check_done(self) -> bool:
        """Set the 'done' event if the queue is done. Returns is_done."""
        if self.is_done:
            self._done.set()
            return True
        return False

    def empty(self) -> bool:
        """Queue has no items except None as sentinel"""
        return self._empty.is_set() or self.qsize() == 0

    def qsize(self) -> int:
        """Number of real items in the queue, excluding the sentinel."""
        size: int = self._Q.qsize()
        # Only discount the sentinel when it really is in the queue. Assuming
        # "filled" implies "sentinel queued" under-counts by one while the
        # sentinel is absent, which could flag the queue as done too early.
        if self._sentinel_queued:
            return size - 1
        return size

    @property
    def wip(self) -> int:
        """Number of items fetched but not yet marked task_done()."""
        return self._wip

    @property
    def has_wip(self) -> bool:
        """True if any fetched item is still waiting for task_done()."""
        return self._wip > 0

    @property
    def count(self) -> int:
        """Number of items marked task_done(), or 0 if counting is disabled."""
        if self._count_items:
            return self._count
        return 0

    async def add_producer(self, N: int = 1) -> int:
        """Add producer(s) to the queue

        Returns the number of registered producers.
        Raises QueueDone if the queue has already been filled."""
        assert N > 0, "N has to be positive"
        async with self._modify:
            if self.is_filled:
                raise QueueDone
            self._producers += N
        return self._producers

    async def finish(self, all: bool = False, empty: bool = False) -> bool:
        """Producer has finished adding items to the queue

        * all: finish() queue for all producers at once
        * empty: discard the items still in the queue, marking them task_done()

        Returns True if this call marked the queue as filled, False otherwise
        (other producers are still active, or the queue was already filled)."""
        async with self._modify:
            if self._producers <= 0 or self.is_filled:
                # Surplus finish() calls are tolerated instead of raising
                self._producers = 0
                return False
            self._producers -= 1
            if not all and self._producers > 0:
                return False

            # This was the last producer (or all=True): close the queue.
            self._filled.set()
            self._producers = 0
            if empty:
                self._discard_items()
            async with self._put_lock:
                if empty:
                    # A put() that held the lock may have added one more item
                    self._discard_items()
                self.check_done()
                await self._Q.put(None)
                self._sentinel_queued = True
                return True

    def _discard_items(self) -> None:
        """Drop every queued item, marking each one as task_done()."""
        try:
            while True:
                self.get_nowait()
                self.task_done()
        except (QueueDone, QueueEmpty):
            pass

    def _validate_put(self, item: T) -> None:
        """Raise unless 'item' may be added to the queue right now."""
        if self.is_filled:
            raise QueueDone
        if self._producers <= 0:
            raise ValueError("No registered producers")
        if item is None:
            raise ValueError("Cannot add None to IterableQueue")

    async def put(self, item: T) -> None:
        """Add an item to the queue, waiting for a free slot if needed.

        Raises QueueDone if the queue is filled and ValueError if there are
        no registered producers or the item is None."""
        async with self._put_lock:
            # Checked inside the lock: the queue may have been filled while
            # this producer was waiting for its turn.
            self._validate_put(item)
            await self._Q.put(item=item)
            self._empty.clear()
        return None

    def put_nowait(self, item: T) -> None:
        """Attempt to implement put_nowait()

        Same checks as put(); the wrapped queue raises QueueFull if no slot
        is free."""
        self._validate_put(item)
        self._Q.put_nowait(item=item)
        self._empty.clear()
        return None

    async def get(self) -> T:
        """Fetch the next item, waiting until one is available.

        The item counts as work in progress until task_done() is called.
        Raises QueueDone once the queue is filled and has no more items."""
        item = await self._Q.get()
        if item is None:
            # Sentinel reached: hand it back for the other consumers.
            self._sentinel_queued = False
            self._empty.set()
            self._Q.task_done()
            self.check_done()
            async with self._put_lock:
                await self._Q.put(None)
                self._sentinel_queued = True
            raise QueueDone
        if self._Q.qsize() == 0:
            self._empty.set()
        # Count the item as WIP right away, without awaiting a lock: a
        # suspension here would leave the item neither queued nor in
        # progress, and check_done() could report the queue as done.
        self._wip += 1
        return item

    def get_nowait(self) -> T:
        """Attempt to implement get_nowait()

        Raises QueueEmpty (from the wrapped queue) if nothing is queued and
        QueueDone if only the sentinel is left."""
        item: Optional[T] = self._Q.get_nowait()
        if item is None:
            self._sentinel_queued = False
            self._empty.set()
            self._Q.task_done()
            self.check_done()
            try:
                self._Q.put_nowait(None)
            except QueueFull:
                # Best effort: the slot just freed should make this succeed
                pass
            else:
                self._sentinel_queued = True
            raise QueueDone
        if self._Q.qsize() == 0:
            self._empty.set()
        self._wip += 1
        return item

    def task_done(self) -> None:
        """Mark one previously fetched item as processed.

        Raises ValueError if called more times than items were fetched."""
        # Validate before touching any counter so a surplus call leaves the
        # queue state intact.
        if self._wip <= 0:
            raise ValueError("task_done() called more than tasks open")
        self._Q.task_done()
        self._count += 1
        self._wip -= 1
        self.check_done()

    async def join(self) -> None:
        """Wait until the queue has been filled and all items are processed."""
        debug("Waiting queue to be filled")
        await self._filled.wait()
        debug("Queue filled, waiting when queue is done")
        await self._done.wait()
        debug("queue is done")
        return None

    def __aiter__(self):
        return self

    async def __anext__(self) -> T:
        """Mark the previous item as done and return the next one."""
        async with self._modify:
            if self._wip > 0:  # do not mark task_done() at first call
                self.task_done()
        try:
            return await self.get()
        except QueueDone:
            raise StopAsyncIteration