Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit d2016ef

Browse files
mannatsingh authored and facebook-github-bot committed
Fix bug in profiler and task init code
Differential Revision: D20342756 Profiler update fbshipit-source-id: 3e4dcb31f1cbf824fca3b74467ed8940f472029c
1 parent 6f2efc6 commit d2016ef

File tree

4 files changed

+85
-19
lines changed

4 files changed

+85
-19
lines changed

classy_vision/generic/profiler.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,12 @@
1010

1111
import torch
1212
import torch.nn as nn
13-
from classy_vision.generic.util import get_model_dummy_input, is_leaf, is_on_gpu
13+
from classy_vision.generic.util import (
14+
get_model_dummy_input,
15+
is_leaf,
16+
is_on_gpu,
17+
train_mode,
18+
)
1419
from torch.cuda import cudart
1520

1621

@@ -24,7 +29,6 @@ def profile(
2429
"""
2530
Performs CPU or GPU profiling of the specified model on the specified input.
2631
"""
27-
2832
# assertions:
2933
if use_nvprof:
3034
raise NotImplementedError
@@ -41,18 +45,19 @@ def profile(
4145
batchsize=batchsize_per_replica,
4246
non_blocking=False,
4347
)
44-
# perform profiling:
45-
with torch.no_grad():
46-
model(input) # warm up CUDA memory allocator and profiler
47-
if use_nvprof: # nvprof profiling (TODO: Can we infer this?)
48-
cudart().cudaProfilerStart()
49-
model(input)
50-
cudart().cudaProfilerStop()
51-
exit() # exit gracefully
52-
else: # regular profiling
53-
with torch.autograd.profiler.profile(use_cuda=True) as profiler:
48+
# perform profiling in eval mode
49+
with train_mode(model, False):
50+
with torch.no_grad():
51+
model(input) # warm up CUDA memory allocator and profiler
52+
if use_nvprof: # nvprof profiling (TODO: Can we infer this?)
53+
cudart().cudaProfilerStart()
5454
model(input)
55-
return profiler
55+
cudart().cudaProfilerStop()
56+
exit() # exit gracefully
57+
else: # regular profiling
58+
with torch.autograd.profiler.profile(use_cuda=True) as profiler:
59+
model(input)
60+
return profiler
5661

5762

5863
def _get_batchsize_per_replica(x):
@@ -376,7 +381,6 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None):
376381
"""
377382
Compute the complexity of a forward pass.
378383
"""
379-
380384
# assertions, input, and upvalue in which we will perform the count:
381385
assert isinstance(model, nn.Module)
382386
if not isinstance(input_shape, abc.Sequence):
@@ -387,7 +391,10 @@ def compute_complexity(model, compute_fn, input_shape, input_key=None):
387391
# measure FLOPs:
388392
modify_forward(model, compute_list, compute_fn)
389393
try:
390-
model.forward(input)
394+
# compute complexity in eval mode
395+
with train_mode(model, False):
396+
with torch.no_grad():
397+
model.forward(input)
391398
except NotImplementedError as err:
392399
raise err
393400
finally:

classy_vision/generic/util.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,3 +774,18 @@ def get_model_dummy_input(
774774
if input_key:
775775
input = {input_key: input}
776776
return input
777+
778+
779+
@contextlib.contextmanager
def train_mode(model: nn.Module, train_mode: bool):
    """Context manager which temporarily sets the train mode of a model.

    On entry, ``model.train(train_mode)`` is called, which switches the model
    and all of its submodules to the requested mode. On exit (normal or via
    an exception), the ``training`` flag of every module that was part of the
    model on entry is restored individually to the value it had before.

    Args:
        model: The model whose train mode should be set for the duration of
            the ``with`` block.
        train_mode: ``True`` for train mode, ``False`` for eval mode.

    Modules registered on the model while the context is active are not part
    of the snapshot; their ``training`` flag is left untouched on exit.
    """
    # Snapshot by module reference instead of by name: restoring through a
    # name-keyed dict raises KeyError in the ``finally`` clause as soon as a
    # submodule is added inside the context, which would abort the restore
    # and mask any exception raised by the body.
    saved_modes = [(module, module.training) for module in model.modules()]
    try:
        model.train(train_mode)
        yield
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

classy_vision/tasks/classification_task.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
recursive_copy_to_gpu,
2424
update_classy_state,
2525
)
26-
from classy_vision.hooks import ClassyHookFunctions
2726
from classy_vision.losses import ClassyLoss, build_loss
2827
from classy_vision.meters import build_meters
2928
from classy_vision.models import ClassyModel, build_model
@@ -305,9 +304,6 @@ def from_config(cls, config: Dict[str, Any]) -> "ClassificationTask":
305304
amp_opt_level = config.get("amp_opt_level")
306305
meters = build_meters(config.get("meters", {}))
307306
model = build_model(config["model"])
308-
# put model in eval mode in case any hooks modify model states, it'll
309-
# be reset to train mode before training
310-
model.eval()
311307
optimizer = build_optimizer(optimizer_config)
312308

313309
task = (

test/generic_util_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import copy
78
import shutil
89
import tempfile
910
import unittest
@@ -14,6 +15,7 @@
1415

1516
import classy_vision.generic.util as util
1617
import torch
18+
import torch.nn as nn
1719
from classy_vision.generic.util import (
1820
CHECKPOINT_FILE,
1921
load_checkpoint,
@@ -368,6 +370,52 @@ def test_get_model_dummy_input(self):
368370
)
369371
self.assertEqual(result.size(), tuple([batchsize] + input_shape))
370372

373+
def _compare_model_train_mode(self, model_1, model_2):
374+
for name_1, module_1 in model_1.named_modules():
375+
found = False
376+
for name_2, module_2 in model_2.named_modules():
377+
if name_1 == name_2:
378+
found = True
379+
if module_1.training != module_2.training:
380+
return False
381+
if not found:
382+
return False
383+
return True
384+
385+
def _check_model_train_mode(self, model, expected_mode):
386+
for module in model.modules():
387+
if module.training != expected_mode:
388+
return False
389+
return True
390+
391+
def test_train_mode(self):
    """Check that ``util.train_mode`` forces the requested mode on every
    module inside the context and restores each module's own mode on exit."""

    class TestModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(1, 2)
            self.dropout = nn.Dropout()
            # BatchNorm2d takes num_features first; match the conv's 2
            # output channels instead of passing 2 as eps
            self.seq = nn.Sequential(
                nn.ReLU(), nn.Conv2d(1, 2, 3), nn.BatchNorm2d(2)
            )

    test_model = TestModel()
    for train in [True, False]:
        test_model.train(train)

        # flip some of the modes
        test_model.dropout.train(not train)
        test_model.seq[1].train(not train)

        # snapshot of the mixed per-module modes to compare against
        orig_model = copy.deepcopy(test_model)
        for context_train in [True, False]:
            with util.train_mode(test_model, context_train):
                # every module must be in the requested mode; the helper
                # only returns a bool, so it has to be asserted explicitly
                self.assertTrue(
                    self._check_model_train_mode(test_model, context_train)
                )
                # the modes should be different inside the context manager
                self.assertFalse(
                    self._compare_model_train_mode(orig_model, test_model)
                )
            self.assertTrue(self._compare_model_train_mode(orig_model, test_model))
418+
371419

372420
class TestUpdateStateFunctions(unittest.TestCase):
373421
def _compare_states(self, state_1, state_2, check_heads=True):

0 commit comments

Comments
 (0)