diff --git a/classy_vision/hooks/tensorboard_plot_hook.py b/classy_vision/hooks/tensorboard_plot_hook.py index 284a52568c..31487d3832 100644 --- a/classy_vision/hooks/tensorboard_plot_hook.py +++ b/classy_vision/hooks/tensorboard_plot_hook.py @@ -140,6 +140,20 @@ def on_phase_end( ) continue + if hasattr(task, "perf_log"): + for perf in task.perf_log: + phase_idx = perf["phase_idx"] + tag = perf["tag"] + for metric_name, metric_value in perf.items(): + if metric_name in ["phase_idx", "tag"]: + continue + + self.tb_writer.add_scalar( + f"Performance/{tag}/{metric_name}", + metric_value, + global_step=phase_idx, + ) + # flush so that the plots aren't lost if training crashes soon after self.tb_writer.flush() logging.info(f"Done plotting to Tensorboard") diff --git a/classy_vision/tasks/classification_task.py b/classy_vision/tasks/classification_task.py index dd9c066ca1..588015f0d1 100644 --- a/classy_vision/tasks/classification_task.py +++ b/classy_vision/tasks/classification_task.py @@ -7,6 +7,7 @@ import copy import enum import logging +import time from typing import Any, Dict, List, Optional, Union import torch @@ -93,6 +94,7 @@ class ClassificationTask(ClassyTask): by the optimizer :var data_iterator: Iterator which can be used to obtain batches :var losses: Loss curve + :var perf_log: list of training speed measurements, to be logged """ @@ -122,6 +124,7 @@ def __init__(self): BroadcastBuffersMode.DISABLED ) self.amp_opt_level = None + self.perf_log = [] def set_checkpoint(self, checkpoint): """Sets checkpoint on task. @@ -809,17 +812,49 @@ def on_start(self, local_variables): self.run_hooks(local_variables, ClassyHookFunctions.on_start.name) def on_phase_start(self, local_variables): + self.phase_start_time_total = time.perf_counter() + self.advance_phase() self.run_hooks(local_variables, ClassyHookFunctions.on_phase_start.name) + self.phase_start_time_train = time.perf_counter() + def on_phase_end(self, local_variables): + self.log_phase_end("train") + logging.info("Syncing meters on phase end...") for meter in self.meters: meter.sync_state() logging.info("...meters synced") barrier() + self.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name) + self.perf_log = [] + + self.log_phase_end("total") def on_end(self, local_variables): self.run_hooks(local_variables, ClassyHookFunctions.on_end.name) + + def log_phase_end(self, tag): + if not self.train: + return + + start_time = ( + self.phase_start_time_train + if tag == "train" + else self.phase_start_time_total + ) + phase_duration = time.perf_counter() - start_time + im_per_sec = ( + self.get_global_batchsize() * self.num_batches_per_phase + ) / phase_duration + self.perf_log.append( + { + "tag": tag, + "phase_idx": self.train_phase_idx, + "epoch_duration": phase_duration, + "im_per_sec": im_per_sec, + } + )