Skip to content

Commit 38ecaef

Browse files
authored
Update Range decorator (#2834)
* Change default methods for nn.Module Signed-off-by: Behrooz <[email protected]> * Remove DataLoader Signed-off-by: Behrooz <[email protected]>
1 parent fe559e5 commit 38ecaef

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

monai/utils/nvtx.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,8 @@
1919
from torch.autograd import Function
2020
from torch.nn import Module
2121
from torch.optim import Optimizer
22-
from torch.utils.data import DataLoader, Dataset
22+
from torch.utils.data import Dataset
2323

24-
# from monai.transforms.transform import Transform
2524
from monai.utils import ensure_tuple, optional_import
2625

2726
_nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
@@ -40,7 +39,7 @@ class Range:
4039
methods: (only when used as decorator) the name of a method (or a list of the name of the methods)
4140
to be wrapped by NVTX range.
4241
If None (default), the method(s) will be inferred based on the object's type for various MONAI components,
43-
such as Networks, Losses, Optimizers, Functions, Transforms, Datasets, and Dataloaders.
42+
such as Networks, Losses, Functions, Transforms, and Datasets.
4443
Otherwise, it look up predefined methods: "forward", "__call__", "__next__", "__getitem__"
4544
append_method_name: if append the name of the methods to be decorated to the range's name
4645
If None (default), it appends the method's name only if we are annotating more than one method.
@@ -114,15 +113,13 @@ def range_wrapper(*args, **kwargs):
114113

115114
def _get_method(self, obj: Any) -> tuple:
116115
if isinstance(obj, Module):
117-
method_list = ["forward", "__call__"]
116+
method_list = ["forward"]
118117
elif isinstance(obj, Optimizer):
119118
method_list = ["step"]
120119
elif isinstance(obj, Function):
121120
method_list = ["forward", "backward"]
122121
elif isinstance(obj, Dataset):
123122
method_list = ["__getitem__"]
124-
elif isinstance(obj, DataLoader):
125-
method_list = ["_next_data"]
126123
else:
127124
default_methods = ["forward", "__call__", "__next__", "__getitem__"]
128125
method_list = []

0 commit comments

Comments
 (0)