Skip to content
This repository was archived by the owner on Jul 1, 2024. It is now read-only.

Performance logging #385

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions classy_vision/hooks/tensorboard_plot_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,20 @@ def on_phase_end(
)
continue

if hasattr(task, "perf_log"):
for perf in task.perf_log:
phase_idx = perf["phase_idx"]
tag = perf["tag"]
for metric_name, metric_value in perf.items():
if metric_name in ["phase_idx", "tag"]:
continue

self.tb_writer.add_scalar(
f"Performance/{tag}/{metric_name}",
metric_value,
global_step=phase_idx,
)

# flush so that the plots aren't lost if training crashes soon after
self.tb_writer.flush()
logging.info(f"Done plotting to Tensorboard")
35 changes: 35 additions & 0 deletions classy_vision/tasks/classification_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import copy
import enum
import logging
import time
from typing import Any, Dict, List, Optional, Union

import torch
Expand Down Expand Up @@ -93,6 +94,7 @@ class ClassificationTask(ClassyTask):
by the optimizer
:var data_iterator: Iterator which can be used to obtain batches
:var losses: Loss curve
:var perf_log: list of training speed measurements, to be logged

"""

Expand Down Expand Up @@ -122,6 +124,7 @@ def __init__(self):
BroadcastBuffersMode.DISABLED
)
self.amp_opt_level = None
self.perf_log = []

def set_checkpoint(self, checkpoint):
"""Sets checkpoint on task.
Expand Down Expand Up @@ -809,17 +812,49 @@ def on_start(self, local_variables):
self.run_hooks(local_variables, ClassyHookFunctions.on_start.name)

def on_phase_start(self, local_variables):
    """Advance to the next phase and start the performance timers.

    Two timestamps are recorded so that `log_phase_end` can report two
    durations per phase: "total" (everything from here to the end of the
    on_phase_end hooks) and "train" (just the training loop itself).
    """
    # Start of the "total" timer: taken before any phase setup so the
    # measured duration includes advance_phase() and the start hooks.
    self.phase_start_time_total = time.perf_counter()

    self.advance_phase()

    self.run_hooks(local_variables, ClassyHookFunctions.on_phase_start.name)

    # Start of the "train" timer: taken after the on_phase_start hooks so
    # the measured duration covers only the training loop, excluding
    # setup and hook overhead.
    self.phase_start_time_train = time.perf_counter()

def on_phase_end(self, local_variables):
    """Finish the phase: record timings, sync meters, and run end hooks.

    Statement order matters here: the "train" timing is logged first (so it
    excludes meter syncing and hooks), hooks then consume ``perf_log``, and
    only afterwards is the "total" timing appended.
    """
    # "train" duration: measured before meter syncing and the end hooks,
    # pairing with phase_start_time_train (set after the start hooks).
    self.log_phase_end("train")

    logging.info("Syncing meters on phase end...")
    for meter in self.meters:
        meter.sync_state()
    logging.info("...meters synced")
    barrier()

    self.run_hooks(local_variables, ClassyHookFunctions.on_phase_end.name)
    # The on_phase_end hooks have consumed perf_log by now; clear it so the
    # same entries are not reported again on the next phase.
    self.perf_log = []

    # "total" duration: includes phase setup, the training loop, meter
    # syncing, and the hooks above.
    # NOTE(review): this entry is appended *after* perf_log was cleared, so
    # hooks only see it at the end of the NEXT phase, and the final phase's
    # "total" entry is never consumed — confirm this is intentional.
    self.log_phase_end("total")

def on_end(self, local_variables):
    """Called once when training finishes; runs the registered on_end hooks."""
    self.run_hooks(local_variables, ClassyHookFunctions.on_end.name)

def log_phase_end(self, tag):
    """Record a throughput measurement for the phase that just ended.

    Appends a dict to ``self.perf_log`` containing the phase index, the
    wall-clock duration since the matching start timestamp, and the
    images-per-second throughput. Does nothing for non-train phases.

    Args:
        tag: "train" measures against ``phase_start_time_train``; any other
            value measures against ``phase_start_time_total``.
    """
    # Only training phases are timed.
    if not self.train:
        return

    if tag == "train":
        start_time = self.phase_start_time_train
    else:
        start_time = self.phase_start_time_total
    phase_duration = time.perf_counter() - start_time

    # Total images processed this phase, across all replicas.
    num_images = self.get_global_batchsize() * self.num_batches_per_phase

    self.perf_log.append(
        {
            "tag": tag,
            "phase_idx": self.train_phase_idx,
            "epoch_duration": phase_duration,
            "im_per_sec": num_images / phase_duration,
        }
    )