Commit 7e05d91

Feature fixed weights hist handler issue 2328 (#2523)
* Remove unnecessary code in BaseOutputHandler. Closes #2438
* Add ReduceLROnPlateauScheduler. Closes #1754
* Fix indentation issue
* Fix another indentation issue
* Fix PEP8 related issues
* Fix other PEP8 related issues
* Fix hopefully the last PEP8 related issue
* Fix hopefully the last PEP8 related issue
* Remove ReduceLROnPlateau's specific params and add link to it. Also fix bug in min_lr check
* Fix state_dict bug and add a test
* Update docs
* Add doctest and fix typo
* Add feature FixedWeightsHistHandler. Closes #2328
* Move FixedWeightsHistHandler's job to WeightsHistHandler. Closes #2328
* Enable whitelist to be callable
* autopep8 fix
* Refactor constructor
* Change whitelist to be List[str]
* Add whitelist callable type
* Fix bug in MNIST tensorboard example

Co-authored-by: vfdev <[email protected]>
Co-authored-by: sadra-barikbin <[email protected]>
1 parent a359694 commit 7e05d91
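For orientation, a minimal usage sketch of the `whitelist` argument added here. This is not code from the commit: it assumes an ignite Engine named `trainer` and an nn.Module named `model`, as in the examples this commit touches.

    # Minimal sketch, not part of this commit; `trainer` and `model` are assumed
    # to exist, as in the MNIST example modified below.
    from ignite.contrib.handlers.tensorboard_logger import TensorboardLogger, WeightsHistHandler
    from ignite.engine import Events

    tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")

    # Log histograms only for parameters whose fully-qualified name starts with "conv".
    tb_logger.attach(
        trainer,
        log_handler=WeightsHistHandler(model, whitelist=["conv"]),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )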

File tree: 4 files changed (+105, -21 lines)

examples/contrib/mnist/mnist_with_tensorboard_logger.py

Lines changed: 11 additions & 2 deletions
@@ -135,11 +135,20 @@ def compute_metrics(engine):
 
     tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
 
-    tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
+    tb_logger.attach(
+        trainer,
+        log_handler=WeightsHistHandler(
+            model,
+            whitelist=[
+                "conv",
+            ],
+        ),
+        event_name=Events.ITERATION_COMPLETED(every=100),
+    )
 
     tb_logger.attach(trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
 
-    tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.EPOCH_COMPLETED(every=100))
+    tb_logger.attach(trainer, log_handler=GradsHistHandler(model), event_name=Events.ITERATION_COMPLETED(every=100))
 
     def score_function(engine):
         return engine.state.metrics["accuracy"]

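The "conv" entry above is matched as a name prefix, so it selects, for example, conv1.weight, conv1.bias, conv2.weight and conv2.bias (assuming the example's Net uses the usual conv1/conv2 layer names). A sketch, not part of this commit, of the same selection written with the callable form of `whitelist`:

    # Sketch only: equivalent callable-form whitelist for the change above;
    # `tb_logger`, `trainer` and `model` come from the example script.
    tb_logger.attach(
        trainer,
        log_handler=WeightsHistHandler(model, whitelist=lambda name, p: name.startswith("conv")),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )
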
ignite/contrib/handlers/tensorboard_logger.py

Lines changed: 65 additions & 4 deletions
@@ -404,6 +404,12 @@ class WeightsHistHandler(BaseWeightsHistHandler):
     Args:
         model: model to log weights
         tag: common title for all produced plots. For example, "generator"
+        whitelist: specific weights to log. Should be list of model's submodules
+            or parameters names, or a callable which gets weight along with its name
+            and determines if it should be logged. Names should be fully-qualified.
+            For more information please refer to `PyTorch docs
+            <https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.get_submodule>`_.
+            If not given, all of model's weights are logged.
 
     Examples:
         .. code-block:: python
@@ -419,20 +425,75 @@ class WeightsHistHandler(BaseWeightsHistHandler):
                 event_name=Events.ITERATION_COMPLETED,
                 log_handler=WeightsHistHandler(model)
             )
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.tensorboard_logger import *
+
+            # Create a logger
+            tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")
+
+            # Log weights of `fc` layer
+            weights = ['fc']
+
+            # Attach the logger to the trainer to log weights norm after each iteration
+            tb_logger.attach(
+                trainer,
+                event_name=Events.ITERATION_COMPLETED,
+                log_handler=WeightsHistHandler(model, whitelist=weights)
+            )
+
+        .. code-block:: python
+
+            from ignite.contrib.handlers.tensorboard_logger import *
+
+            # Create a logger
+            tb_logger = TensorboardLogger(log_dir="experiments/tb_logs")
+
+            # Log weights which name include 'conv'.
+            weight_selector = lambda name, p: 'conv' in name
+
+            # Attach the logger to the trainer to log weights norm after each iteration
+            tb_logger.attach(
+                trainer,
+                event_name=Events.ITERATION_COMPLETED,
+                log_handler=WeightsHistHandler(model, whitelist=weight_selector)
+            )
     """
 
-    def __init__(self, model: nn.Module, tag: Optional[str] = None):
+    def __init__(
+        self,
+        model: nn.Module,
+        tag: Optional[str] = None,
+        whitelist: Optional[Union[List[str], Callable[[str, nn.Parameter], bool]]] = None,
+    ):
         super(WeightsHistHandler, self).__init__(model, tag=tag)
 
+        weights = {}
+        if whitelist is None:
+
+            weights = dict(model.named_parameters())
+        elif callable(whitelist):
+
+            for n, p in model.named_parameters():
+                if whitelist(n, p):
+                    weights[n] = p
+        else:
+
+            for n, p in model.named_parameters():
+                for item in whitelist:
+                    if n.startswith(item):
+                        weights[n] = p
+
+        self.weights = weights.items()
+
     def __call__(self, engine: Engine, logger: TensorboardLogger, event_name: Union[str, Events]) -> None:
         if not isinstance(logger, TensorboardLogger):
             raise RuntimeError("Handler 'WeightsHistHandler' works only with TensorboardLogger")
 
         global_step = engine.state.get_event_attrib_value(event_name)
         tag_prefix = f"{self.tag}/" if self.tag else ""
-        for name, p in self.model.named_parameters():
-            if p.grad is None:
-                continue
+        for name, p in self.weights:
 
             name = name.replace(".", "/")
             logger.writer.add_histogram(

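To make the selection rules above concrete, here is a standalone sketch (not code from this commit; the toy model is hypothetical): a parameter is kept when its fully-qualified name starts with one of the whitelist strings, or when the callable returns True for (name, parameter).

    # Standalone sketch of the whitelist semantics implemented in the constructor above.
    from collections import OrderedDict
    import torch.nn as nn
    from ignite.contrib.handlers.tensorboard_logger import WeightsHistHandler

    toy = nn.Sequential(OrderedDict([("fc1", nn.Linear(4, 4)), ("fc2", nn.Linear(4, 2))]))

    # Submodule-name prefix keeps both weight and bias of fc1.
    h = WeightsHistHandler(toy, whitelist=["fc1"])
    print([name for name, _ in h.weights])  # ['fc1.weight', 'fc1.bias']

    # A full parameter name also works, since matching uses str.startswith.
    h = WeightsHistHandler(toy, whitelist=["fc2.weight"])
    print([name for name, _ in h.weights])  # ['fc2.weight']

    # Callable form: keep only bias parameters.
    h = WeightsHistHandler(toy, whitelist=lambda name, p: "bias" in name)
    print([name for name, _ in h.weights])  # ['fc1.bias', 'fc2.bias']
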
tests/ignite/contrib/handlers/conftest.py

Lines changed: 4 additions & 1 deletion
@@ -28,7 +28,7 @@ def __init__(self):
             self.fc2.weight.data.fill_(1.0)
             self.fc2.bias.data.fill_(1.0)
 
-    def get_dummy_model(with_grads=True, with_frozen_layer=False):
+    def get_dummy_model(with_grads=True, with_frozen_layer=False, with_buffer=False):
         model = DummyModel()
         if with_grads:
             model.fc2.weight.grad = torch.zeros_like(model.fc2.weight)
@@ -41,6 +41,9 @@ def get_dummy_model(with_grads=True, with_frozen_layer=False):
         if with_frozen_layer:
             for param in model.fc1.parameters():
                 param.requires_grad = False
+
+        if with_buffer:
+            model.register_buffer("buffer1", torch.ones(1))
         return model
 
     return get_dummy_model

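The new `with_buffer` flag registers a tensor that is a buffer rather than a parameter. A quick sketch (plain PyTorch, not code from this commit) of why that is worth exercising: buffers appear in state_dict() and named_buffers() but not in named_parameters(), which is the collection the handler filters and logs, so a handler that iterates named_parameters() never sees them.

    # Sketch: buffers are not parameters, so they are absent from named_parameters().
    import torch
    import torch.nn as nn

    m = nn.Linear(2, 2)
    m.register_buffer("buffer1", torch.ones(1))

    print([n for n, _ in m.named_parameters()])  # ['weight', 'bias'] -- no buffer1
    print([n for n, _ in m.named_buffers()])     # ['buffer1']
    print("buffer1" in m.state_dict())           # True
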
tests/ignite/contrib/handlers/test_tensorboard_logger.py

Lines changed: 25 additions & 14 deletions
@@ -380,36 +380,47 @@ def _test(tag=None):
     _test(tag="tag")
 
 
-def test_weights_hist_handler_frozen_layers(dummy_model_factory):
+def test_weights_hist_handler_whitelist(dummy_model_factory):
+    model = dummy_model_factory()
 
-    model = dummy_model_factory(with_grads=True, with_frozen_layer=True)
-
-    wrapper = WeightsHistHandler(model)
+    wrapper = WeightsHistHandler(model, whitelist=["fc2.weight"])
     mock_logger = MagicMock(spec=TensorboardLogger)
     mock_logger.writer = MagicMock()
 
     mock_engine = MagicMock()
     mock_engine.state = State()
     mock_engine.state.epoch = 5
 
+    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
+    mock_logger.writer.add_histogram.assert_called_once_with(tag="weights/fc2/weight", values=ANY, global_step=5)
+    mock_logger.writer.reset_mock()
+
+    wrapper = WeightsHistHandler(model, tag="model", whitelist=["fc1"])
     wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
 
     mock_logger.writer.add_histogram.assert_has_calls(
         [
-            call(tag="weights/fc2/weight", values=ANY, global_step=5),
-            call(tag="weights/fc2/bias", values=ANY, global_step=5),
+            call(tag="model/weights/fc1/weight", values=ANY, global_step=5),
+            call(tag="model/weights/fc1/bias", values=ANY, global_step=5),
         ],
         any_order=True,
     )
+    assert mock_logger.writer.add_histogram.call_count == 2
+    mock_logger.writer.reset_mock()
 
-    with pytest.raises(AssertionError):
-        mock_logger.writer.add_histogram.assert_has_calls(
-            [
-                call(tag="weights/fc1/weight", values=ANY, global_step=5),
-                call(tag="weights/fc1/bias", values=ANY, global_step=5),
-            ],
-            any_order=True,
-        )
+    def weight_selector(n, _):
+        return "bias" in n
+
+    wrapper = WeightsHistHandler(model, tag="model", whitelist=weight_selector)
+    wrapper(mock_engine, mock_logger, Events.EPOCH_STARTED)
+
+    mock_logger.writer.add_histogram.assert_has_calls(
+        [
+            call(tag="model/weights/fc1/bias", values=ANY, global_step=5),
+            call(tag="model/weights/fc2/bias", values=ANY, global_step=5),
+        ],
+        any_order=True,
+    )
     assert mock_logger.writer.add_histogram.call_count == 2