Skip to content

Commit 895592e

Browse files
authored
Add thread args to ThreadBuffer (#2862)
* [DLMED] add args to ThreadDataLoader Signed-off-by: Nic Ma <[email protected]> * [DLMED] fix flake8 Signed-off-by: Nic Ma <[email protected]>
1 parent 67aa4cf commit 895592e

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

monai/data/thread_buffer.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ class ThreadBuffer:
3333
timeout: Time to wait for an item from the buffer, or to wait while the buffer is full when adding items
3434
"""
3535

36-
def __init__(self, src, buffer_size=1, timeout=0.01):
36+
def __init__(self, src, buffer_size: int = 1, timeout: float = 0.01):
3737
self.src = src
3838
self.buffer_size = buffer_size
3939
self.timeout = timeout
40-
self.buffer = Queue(self.buffer_size)
40+
self.buffer: Queue = Queue(self.buffer_size)
4141
self.gen_thread = None
4242
self.is_running = False
4343

@@ -82,11 +82,27 @@ class ThreadDataLoader(DataLoader):
8282
Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. This will
8383
iterate over data from the loader as expected however the data is generated on a separate thread. Use this class
8484
where a `DataLoader` instance is required and not just an iterable object.
85+
86+
Args:
87+
dataset: input dataset.
88+
buffer_size: number of items to buffer from the data source.
89+
buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
90+
num_workers: number of the multi-prcessing workers in PyTorch DataLoader.
91+
8592
"""
8693

87-
def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs):
94+
def __init__(
95+
self,
96+
dataset: Dataset,
97+
buffer_size: int = 1,
98+
buffer_timeout: float = 0.01,
99+
num_workers: int = 0,
100+
**kwargs,
101+
):
88102
super().__init__(dataset, num_workers, **kwargs)
103+
self.buffer_size = buffer_size
104+
self.buffer_timeout = buffer_timeout
89105

90106
def __iter__(self):
91-
buffer = ThreadBuffer(super().__iter__())
107+
buffer = ThreadBuffer(src=super().__iter__(), buffer_size=self.buffer_size, timeout=self.buffer_timeout)
92108
yield from buffer

0 commit comments

Comments
 (0)