@@ -33,11 +33,11 @@ class ThreadBuffer:
33
33
timeout: Time to wait for an item from the buffer, or to wait while the buffer is full when adding items
34
34
"""
35
35
36
- def __init__ (self , src , buffer_size = 1 , timeout = 0.01 ):
36
+ def __init__ (self , src , buffer_size : int = 1 , timeout : float = 0.01 ):
37
37
self .src = src
38
38
self .buffer_size = buffer_size
39
39
self .timeout = timeout
40
- self .buffer = Queue (self .buffer_size )
40
+ self .buffer : Queue = Queue (self .buffer_size )
41
41
self .gen_thread = None
42
42
self .is_running = False
43
43
@@ -82,11 +82,27 @@ class ThreadDataLoader(DataLoader):
82
82
Subclass of `DataLoader` using a `ThreadBuffer` object to implement `__iter__` method asynchronously. This will
83
83
iterate over data from the loader as expected however the data is generated on a separate thread. Use this class
84
84
where a `DataLoader` instance is required and not just an iterable object.
85
+
86
+ Args:
87
+ dataset: input dataset.
88
+ buffer_size: number of items to buffer from the data source.
89
+ buffer_timeout: time to wait for an item from the buffer, or to wait while the buffer is full when adding items.
90
+ num_workers: number of multi-processing workers in PyTorch DataLoader.
91
+
85
92
"""
86
93
87
- def __init__ (self , dataset : Dataset , num_workers : int = 0 , ** kwargs ):
94
+ def __init__ (
95
+ self ,
96
+ dataset : Dataset ,
97
+ buffer_size : int = 1 ,
98
+ buffer_timeout : float = 0.01 ,
99
+ num_workers : int = 0 ,
100
+ ** kwargs ,
101
+ ):
88
102
super ().__init__ (dataset , num_workers , ** kwargs )
103
+ self .buffer_size = buffer_size
104
+ self .buffer_timeout = buffer_timeout
89
105
90
106
def __iter__ (self ):
91
- buffer = ThreadBuffer (super ().__iter__ ())
107
+ buffer = ThreadBuffer (src = super ().__iter__ (), buffer_size = self . buffer_size , timeout = self . buffer_timeout )
92
108
yield from buffer
0 commit comments