Skip to content

Commit 713490b

Browse files
authored
Garbage Collector Handler (#1940)
* Implement garbage collector handler Signed-off-by: Behrooz <[email protected]> * Make trigger_event lower case Signed-off-by: Behrooz <[email protected]> * Add unittest for garbage collector Signed-off-by: Behrooz <[email protected]> * Update docs Signed-off-by: Behrooz <[email protected]> * Exclude from min test Signed-off-by: Behrooz <[email protected]> * Fix a typo Signed-off-by: Behrooz <[email protected]> * Fix a bug Signed-off-by: Behrooz <[email protected]>
1 parent a6b8f2a commit 713490b

File tree

5 files changed

+164
-0
lines changed

5 files changed

+164
-0
lines changed

docs/source/handlers.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,8 @@ EarlyStop handler
115115
-----------------
116116
.. autoclass:: EarlyStopHandler
117117
:members:
118+
119+
GarbageCollector handler
120+
------------------------
121+
.. autoclass:: GarbageCollector
122+
:members:

monai/handlers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from .classification_saver import ClassificationSaver
1515
from .confusion_matrix import ConfusionMatrix
1616
from .earlystop_handler import EarlyStopHandler
17+
from .garbage_collector import GarbageCollector
1718
from .hausdorff_distance import HausdorffDistance
1819
from .iteration_metric import IterationMetric
1920
from .lr_schedule_handler import LrScheduleHandler

monai/handlers/garbage_collector.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import gc
13+
from typing import TYPE_CHECKING
14+
15+
from monai.utils import exact_version, optional_import
16+
17+
if TYPE_CHECKING:
18+
from ignite.engine import Engine, Events
19+
else:
20+
Engine, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Engine")
21+
Events, _ = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
22+
23+
24+
class GarbageCollector:
25+
"""
26+
Run garbage collector after each epoch
27+
28+
Args:
29+
trigger_event: the event that trigger a call to this handler.
30+
- "epoch", after completion of each epoch (equivalent of ignite.engine.Events.EPOCH_COMPLETED)
31+
- "iteration", after completion of each iteration (equivalent of ignite.engine.Events.ITERATION_COMPLETED)
32+
- any ignite built-in event from ignite.engine.Events.
33+
Defaults to "epoch".
34+
log_level: log level (integer) for some garbage collection information as below. Defaults to 10 (DEBUG).
35+
- 50 (CRITICAL)
36+
- 40 (ERROR)
37+
- 30 (WARNING)
38+
- 20 (INFO)
39+
- 10 (DEBUG)
40+
- 0 (NOTSET)
41+
"""
42+
43+
def __init__(self, trigger_event: str = "epoch", log_level: int = 10):
44+
if isinstance(trigger_event, Events):
45+
self.trigger_event = trigger_event
46+
elif trigger_event.lower() == "epoch":
47+
self.trigger_event = Events.EPOCH_COMPLETED
48+
elif trigger_event.lower() == "iteration":
49+
self.trigger_event = Events.ITERATION_COMPLETED
50+
else:
51+
raise ValueError(
52+
f"'trigger_event' should be either epoch, iteration, or an ignite built-in event from"
53+
f" ignite.engine.Events, '{trigger_event}' was given."
54+
)
55+
56+
self.log_level = log_level
57+
58+
def attach(self, engine: Engine) -> None:
59+
if not engine.has_event_handler(self, self.trigger_event):
60+
engine.add_event_handler(self.trigger_event, self)
61+
62+
def __call__(self, engine: Engine) -> None:
63+
"""
64+
This method calls python garbage collector.
65+
66+
Args:
67+
engine: Ignite Engine, it should be either a trainer or validator.
68+
"""
69+
# get count before garbage collection
70+
pre_count = gc.get_count()
71+
# fits call to garbage collector
72+
gc.collect()
73+
# second call to garbage collector
74+
unreachable = gc.collect()
75+
# get count after garbage collection
76+
after_count = gc.get_count()
77+
engine.logger.log(
78+
self.log_level,
79+
f"Garbage Count: [before: {pre_count}] -> [after: {after_count}] (unreachable : {unreachable})",
80+
)

tests/min_tests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def run_testsuit():
4242
"test_handler_confusion_matrix",
4343
"test_handler_confusion_matrix_dist",
4444
"test_handler_hausdorff_distance",
45+
"test_handler_garbage_collector",
4546
"test_handler_mean_dice",
4647
"test_handler_prob_map_producer",
4748
"test_handler_rocauc",
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
# Copyright 2020 - 2021 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import gc
13+
import unittest
14+
from unittest import skipUnless
15+
16+
import torch
17+
from ignite.engine import Engine
18+
from parameterized import parameterized
19+
20+
from monai.data import Dataset
21+
from monai.handlers import GarbageCollector
22+
from monai.utils import exact_version, optional_import
23+
24+
Events, has_ignite = optional_import("ignite.engine", "0.4.4", exact_version, "Events")
25+
26+
27+
TEST_CASE_0 = [[0, 1, 2], "epoch"]
28+
29+
TEST_CASE_1 = [[0, 1, 2], "iteration"]
30+
31+
TEST_CASE_2 = [[0, 1, 2], Events.EPOCH_COMPLETED]
32+
33+
34+
class TestHandlerGarbageCollector(unittest.TestCase):
35+
@skipUnless(has_ignite, "Requires ignite")
36+
@parameterized.expand(
37+
[
38+
TEST_CASE_0,
39+
TEST_CASE_1,
40+
TEST_CASE_2,
41+
]
42+
)
43+
def test_content(self, data, trigger_event):
44+
# set up engine
45+
gb_count_dict = {}
46+
47+
def _train_func(engine, batch):
48+
# store garbage collection counts
49+
if trigger_event == Events.EPOCH_COMPLETED or trigger_event.lower() == "epoch":
50+
if engine.state.iteration % engine.state.epoch_length == 1:
51+
gb_count_dict[engine.state.epoch] = gc.get_count()
52+
elif trigger_event.lower() == "iteration":
53+
gb_count_dict[engine.state.iteration] = gc.get_count()
54+
55+
engine = Engine(_train_func)
56+
57+
# set up testing handler
58+
dataset = Dataset(data, transform=None)
59+
data_loader = torch.utils.data.DataLoader(dataset, batch_size=1)
60+
GarbageCollector(trigger_event=trigger_event, log_level=30).attach(engine)
61+
62+
engine.run(data_loader, max_epochs=5)
63+
print(gb_count_dict)
64+
65+
first_count = 0
66+
for epoch, gb_count in gb_count_dict.items():
67+
# At least one zero-generation object
68+
self.assertGreater(gb_count[0], 0)
69+
if epoch == 1:
70+
first_count = gb_count[0]
71+
else:
72+
# The should be less number of collected objects in the next calls.
73+
self.assertLess(gb_count[0], first_count)
74+
75+
76+
if __name__ == "__main__":
77+
unittest.main()

0 commit comments

Comments
 (0)