Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Commit ba2db75

Browse files
vreis authored and facebook-github-bot committed
Performance logging (#385)
Summary: This changes ClassificationTask to compute some high-level performance numbers (img/sec) and plot them in Tensorboard. This is useful for comparing performance optimizations since we now get a "blessed" performance number. Also, this was done in a way that's comparable to NVidia's benchmarks (e.g. https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch/performance), so we know how well we're doing compared to other implementations. In terms of implementation, I could have made a hook instead, but decided against it for two reasons: (1) it would introduce dependencies between hooks; (2) we want to control precisely when the timing measurements are taken; Pull Request resolved: #385 Test Plan: ./classy_train.py --config configs/template_config.json Reviewed By: mannatsingh Differential Revision: D19739656 Pulled By: vreis fbshipit-source-id: 1d2c34b93f58ed7218674d4415925ad7189e4359
1 parent 7c25113 commit ba2db75

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

classy_vision/hooks/tensorboard_plot_hook.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,20 @@ def on_phase_end(
140140
)
141141
continue
142142

143+
if hasattr(task, "perf_log"):
144+
for perf in task.perf_log:
145+
phase_idx = perf["phase_idx"]
146+
tag = perf["tag"]
147+
for metric_name, metric_value in perf.items():
148+
if metric_name in ["phase_idx", "tag"]:
149+
continue
150+
151+
self.tb_writer.add_scalar(
152+
f"Speed/{tag}/{metric_name}",
153+
metric_value,
154+
global_step=phase_idx,
155+
)
156+
143157
# flush so that the plots aren't lost if training crashes soon after
144158
self.tb_writer.flush()
145159
logging.info(f"Done plotting to Tensorboard")

classy_vision/tasks/classification_task.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import copy
88
import enum
99
import logging
10+
import time
1011
from typing import Any, Dict, List, Optional, Union
1112

1213
import torch
@@ -93,6 +94,7 @@ class ClassificationTask(ClassyTask):
9394
by the optimizer
9495
:var data_iterator: Iterator which can be used to obtain batches
9596
:var losses: Loss curve
97+
:var perf_log: list of training speed measurements, to be logged
9698
9799
"""
98100

@@ -122,6 +124,7 @@ def __init__(self):
122124
BroadcastBuffersMode.DISABLED
123125
)
124126
self.amp_opt_level = None
127+
self.perf_log = []
125128

126129
def set_checkpoint(self, checkpoint):
127130
"""Sets checkpoint on task.
@@ -809,17 +812,49 @@ def on_start(self, local_variables):
809812
self.run_hooks(local_variables, ClassyHookFunctions.on_start.name)
810813

811814
def on_phase_start(self, local_variables):
815+
self.phase_start_time_total = time.perf_counter()
816+
812817
self.advance_phase()
813818

814819
self.run_hooks(local_variables, ClassyHookFunctions.on_phase_start.name)
815820

821+
self.phase_start_time_train = time.perf_counter()
822+
816823
def on_phase_end(self, local_variables):
824+
self.log_phase_end("train")
825+
817826
logging.info("Syncing meters on phase end...")
818827
for meter in self.meters:
819828
meter.sync_state()
820829
logging.info("...meters synced")
821830
barrier()
831+
822832
self.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name)
833+
self.perf_log = []
834+
835+
self.log_phase_end("total")
823836

824837
def on_end(self, local_variables):
825838
self.run_hooks(local_variables, ClassyHookFunctions.on_end.name)
839+
840+
def log_phase_end(self, tag):
841+
if not self.train:
842+
return
843+
844+
start_time = (
845+
self.phase_start_time_train
846+
if tag == "train"
847+
else self.phase_start_time_total
848+
)
849+
phase_duration = time.perf_counter() - start_time
850+
im_per_sec = (
851+
self.get_global_batchsize() * self.num_batches_per_phase
852+
) / phase_duration
853+
self.perf_log.append(
854+
{
855+
"tag": tag,
856+
"phase_idx": self.train_phase_idx,
857+
"epoch_duration": phase_duration,
858+
"im_per_sec": im_per_sec,
859+
}
860+
)

0 commit comments

Comments (0)