Commit 5bc1702

JKSenthil authored and facebook-github-bot committed
move out zero grad logic into separate function (#969)
Summary:
Pull Request resolved: #969

# Context

Currently it isn't possible to log gradients from AutoUnit, as they are zeroed out before `on_train_step_end()` is reached.

# This Diff

Moves the zero-grad logic out of `_update_weights` and into its own function, which can be overridden, e.g.

```
class MyAutoUnit(AutoUnit):
    ...

    def zero_grad(self) -> None:
        self.logger.log(self.module.grad)
        super().zero_grad()
```

to log the gradients prior to zeroing them out.

Reviewed By: galrotem, diego-urgell

Differential Revision: D68983117

fbshipit-source-id: 744b72c5634d8b6979ef1145fc3254ddde93d743
1 parent 93347d9 commit 5bc1702

File tree

1 file changed: +17 −1 lines changed


torchtnt/framework/auto_unit.py

Lines changed: 17 additions & 1 deletion
```diff
@@ -829,6 +829,22 @@ def step_lr_scheduler(self) -> None:
         """
         none_throws(self.lr_scheduler).step()
 
+    def zero_grad(self) -> None:
+        """
+        Zeroes the gradients of the module's parameters. Override this if you need to log the gradients before zeroing them.
+
+        Example of overriding:
+            class CustomAutoUnit(MyAutoUnit):
+                ...
+
+                def zero_grad(self):
+                    # log before zeroing gradients
+                    super().zero_grad()
+        """
+        optimizer = none_throws(self.optimizer)
+        optimizer.zero_grad(set_to_none=True)
+
     def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         """
         Updates weights of the module, handles clip gradient norm, etc.
@@ -892,7 +908,7 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
             with get_timing_context(
                 state, f"{self.__class__.__name__}.optimizer_zero_grad"
             ):
-                optimizer.zero_grad(set_to_none=True)
+                self.zero_grad()
 
             if self.step_lr_interval == "step":
                 self._update_lr_and_swa(state, self.train_progress.num_steps_completed)
```
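The diff turns `zero_grad` into an overridable hook that `_update_weights` calls. A minimal, dependency-free sketch of that pattern is below; `Unit`, `LoggingUnit`, and the `grads` dict are hypothetical stand-ins for illustration, not torchtnt APIs:

```python
# Sketch of the hook pattern this commit introduces: the base class's update
# step calls self.zero_grad(), so a subclass can intercept gradients before
# they are cleared. All names here are illustrative stand-ins.

class Unit:
    def __init__(self):
        # Pretend gradients, mimicking parameter grads after a backward pass.
        self.grads = {"w": 0.5}

    def zero_grad(self):
        # Base behavior: clear gradients, analogous to
        # optimizer.zero_grad(set_to_none=True) in the real AutoUnit.
        self.grads = {k: None for k in self.grads}

    def _update_weights(self):
        # ... optimizer step would happen here ...
        self.zero_grad()  # hook point: dispatches to the subclass override


class LoggingUnit(Unit):
    def __init__(self):
        super().__init__()
        self.logged = []

    def zero_grad(self):
        # Log (here: record) the gradients before deferring to the base class.
        self.logged.append(dict(self.grads))
        super().zero_grad()


unit = LoggingUnit()
unit._update_weights()
print(unit.logged)  # → [{'w': 0.5}]
print(unit.grads)   # → {'w': None}
```

Because `_update_weights` calls `self.zero_grad()` rather than clearing gradients inline, the subclass sees the gradients exactly once per step, at the moment before they are discarded.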

0 commit comments
