diff --git a/monai/data/thread_buffer.py b/monai/data/thread_buffer.py index 8ea71e3555..2901335bd5 100644 --- a/monai/data/thread_buffer.py +++ b/monai/data/thread_buffer.py @@ -87,8 +87,6 @@ class ThreadDataLoader(DataLoader): def __init__(self, dataset: Dataset, num_workers: int = 0, **kwargs): super().__init__(dataset, num_workers, **kwargs) - # ThreadBuffer will use the inherited __iter__ instead of the one defined below - self.buffer = ThreadBuffer(super().__iter__()) - def __iter__(self): - yield from self.buffer + buffer = ThreadBuffer(super().__iter__()) + yield from buffer diff --git a/tests/test_thread_buffer.py b/tests/test_thread_buffer.py index 1b3ebb910d..507b6909be 100644 --- a/tests/test_thread_buffer.py +++ b/tests/test_thread_buffer.py @@ -48,6 +48,8 @@ def test_dataloader(self): for d in dataloader: self.assertEqual(d["image"][0], "spleen_19.nii.gz") self.assertEqual(d["image"][1], "spleen_31.nii.gz") + + for d in dataloader: self.assertEqual(d["label"][0], "spleen_label_19.nii.gz") self.assertEqual(d["label"][1], "spleen_label_31.nii.gz")