Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit 6b37a71

Browse files
vreis authored and facebook-github-bot committed
Performance logging
Summary: Pull Request resolved: #385 Test Plan: . Differential Revision: D19739656 Pulled By: vreis fbshipit-source-id: 347772745f2811bf2947128a23986161395c526d
1 parent aad649e commit 6b37a71

File tree

2 files changed

+51
-0
lines changed

2 files changed

+51
-0
lines changed

classy_vision/hooks/tensorboard_plot_hook.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,21 @@ def on_phase_end(
140140
)
141141
continue
142142

143+
if hasattr(task, "perf_log"):
144+
for perf in task.perf_log:
145+
phase_idx = perf["phase_idx"]
146+
tag = perf["tag"]
147+
for metric_name, metric_value in perf.items():
148+
if metric_name in ["phase_idx", "tag"]:
149+
continue
150+
151+
self.tb_writer.add_scalar(
152+
f"Performance/{tag}/{metric_name}",
153+
metric_value,
154+
global_step=phase_idx,
155+
)
156+
task.perf_log = []
157+
143158
# flush so that the plots aren't lost if training crashes soon after
144159
self.tb_writer.flush()
145160
logging.info(f"Done plotting to Tensorboard")

classy_vision/tasks/classification_task.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import copy
88
import enum
99
import logging
10+
import time
1011
from typing import Any, Dict, List, Optional, Union
1112

1213
import torch
@@ -807,18 +808,53 @@ def get_global_batchsize(self):
807808
def on_start(self, local_variables):
    """Run the ``on_start`` hooks, then create an empty performance log.

    Args:
        local_variables: Passed through unchanged to ``run_hooks``.
    """
    hook_name = ClassyHookFunctions.on_start.name
    self.run_hooks(local_variables, hook_name)

    # Per-phase performance records are appended here by log_phase_end().
    self.perf_log = []
812+
810813
def on_phase_start(self, local_variables):
    """Advance to the next phase and run the ``on_phase_start`` hooks.

    Two ``time.perf_counter()`` timestamps are recorded for later use by
    ``log_phase_end``: one before any phase setup (the "total" clock) and
    one after the hooks have run (the "train" clock).

    Args:
        local_variables: Passed through unchanged to ``run_hooks``.
    """
    # Taken first so the "total" measurement also covers advance_phase()
    # and the on_phase_start hooks.
    self.phase_start_time_total = time.perf_counter()

    self.advance_phase()

    hook_name = ClassyHookFunctions.on_phase_start.name
    self.run_hooks(local_variables, hook_name)

    # Taken last so the "train" measurement starts after phase setup.
    self.phase_start_time_train = time.perf_counter()
821+
815822
def on_phase_end(self, local_variables):
    """Record phase timings, sync meters and run the ``on_phase_end`` hooks.

    Args:
        local_variables: Passed through unchanged to ``run_hooks``.
    """
    # The "train" record is taken before meter syncing, the barrier and the
    # hooks, so none of that time is attributed to it.
    self.log_phase_end("train")

    logging.info("Syncing meters on phase end...")
    for current_meter in self.meters:
        current_meter.sync_state()
    logging.info("...meters synced")
    barrier()

    hook_name = ClassyHookFunctions.on_phase_end.name
    self.run_hooks(local_variables, hook_name)

    # The "total" record is appended only after the hooks have run.
    # NOTE(review): a hook that consumes perf_log inside on_phase_end will
    # therefore not see this record until a later call - confirm intended.
    self.log_phase_end("total")
834+
823835
def on_end(self, local_variables):
    """Run the ``on_end`` hooks.

    Args:
        local_variables: Passed through unchanged to ``run_hooks``.
    """
    hook_name = ClassyHookFunctions.on_end.name
    self.run_hooks(local_variables, hook_name)
837+
838+
def log_phase_end(self, tag):
    """Append a duration/throughput record for the current phase to ``self.perf_log``.

    Does nothing when ``self.train`` is falsy.

    Args:
        tag: Which clock to measure against. ``"train"`` uses
            ``self.phase_start_time_train``; any other value uses
            ``self.phase_start_time_total``.

    The appended record is a dict with the keys ``tag``, ``phase_idx``
    (``self.train_phase_idx``), ``epoch_duration`` (elapsed
    ``time.perf_counter()`` seconds since the chosen start time) and
    ``im_per_sec`` (``get_global_batchsize() * len(dataloader)`` divided by
    the duration, or ``0.0`` when the duration is not positive).
    """
    if not self.train:
        return

    # Internal invariant: a task in train mode must be in a "train" phase.
    assert self.phase_type == "train"

    if tag == "train":
        start_time = self.phase_start_time_train
    else:
        start_time = self.phase_start_time_total
    phase_duration = time.perf_counter() - start_time

    num_images = self.get_global_batchsize() * len(self.dataloaders[self.phase_type])
    # Guard the division: a zero duration would raise ZeroDivisionError and a
    # negative one would yield a meaningless negative throughput.
    im_per_sec = num_images / phase_duration if phase_duration > 0 else 0.0

    self.perf_log.append(
        {
            "tag": tag,
            "phase_idx": self.train_phase_idx,
            "epoch_duration": phase_duration,
            "im_per_sec": im_per_sec,
        }
    )

0 commit comments

Comments
 (0)