Commit fb2f350

diego-urgell authored and facebook-github-bot committed
Add signpost interval for _generate_checkpoint_and_upkeep (#979)
Summary: Pull Request resolved: #979

Reviewed By: anshulverma, JKSenthil

Differential Revision: D70585216

fbshipit-source-id: 0adeafe0269027bce2579ded61846e072afc455b
1 parent 714ae04 commit fb2f350
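
This commit wraps the body of _generate_checkpoint_and_upkeep in a log_interval context manager from torchtnt.utils.event_handlers, so every checkpointing attempt emits a start/end signpost tagged with the active phase, hook, epoch, and step. The sketch below illustrates the signpost-interval pattern in general terms; the log_interval defined here is a hypothetical stand-in that mirrors only the call shape visible in the diff (a name plus a str-to-str metadata dict, used as a context manager), not torchtnt's actual implementation.

    # Minimal sketch of the signpost-interval pattern. This log_interval is an
    # illustrative stand-in, NOT torchtnt.utils.event_handlers.log_interval;
    # only the call shape (name + str->str metadata, used as a context manager)
    # is taken from the diff.
    import logging
    import time
    from contextlib import contextmanager
    from typing import Dict, Generator

    logger = logging.getLogger(__name__)


    @contextmanager
    def log_interval(
        name: str, metadata: Dict[str, str]
    ) -> Generator[None, None, None]:
        # Emit a start signpost, run the wrapped block, then emit an end
        # signpost with elapsed time, even if the block raises or returns early.
        start = time.perf_counter()
        logger.info("interval start: %s %s", name, metadata)
        try:
            yield
        finally:
            elapsed = time.perf_counter() - start
            logger.info("interval end: %s (%.3fs)", name, elapsed)


    # Usage mirroring the diff: wrap the instrumented function's body.
    with log_interval(
        "_generate_checkpoint_and_upkeep", metadata={"category": "checkpointing"}
    ):
        pass  # checkpoint generation and upkeep would run here

Because the entire method body, including each early return False, now sits inside the with block, the interval's exit handler runs on every path, so skipped and failed checkpoint attempts are closed out and timed just like successful ones.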


torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 71 additions & 55 deletions
@@ -10,7 +10,7 @@
 import logging
 import math
 from datetime import timedelta
-from typing import Any, cast, Iterable, List, Literal, Optional, Union
+from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union

 import fsspec

@@ -39,6 +39,7 @@
     Phase,
 )
 from torchtnt.utils.distributed import get_world_size, PGWrapper
+from torchtnt.utils.event_handlers import log_interval
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

 logger: logging.Logger = logging.getLogger(__name__)

@@ -201,70 +202,85 @@ def _generate_checkpoint_and_upkeep(
         Returns:
             True if checkpoint was successfully saved. False otherwise.
         """
-        # 1) generate checkpoint name
-        epoch = _get_epoch(state, unit)
-        step_mapping = _get_step_phase_mapping(state, unit)
-
-        # 1.1) append metric data only if best_checkpoint_config is defined
-        metric_data: Optional[MetricData] = None
-        if self._best_checkpoint_config and (
-            metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
-        ):
-            metric_data = MetricData(
-                name=none_throws(self._best_checkpoint_config).monitored_metric,
-                value=metric_value,
-            )
-
-        checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
-            epoch,
-            step_mapping,
-            metric_data,
-            process_group=self._process_group,
-        )
-
-        # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
-        # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
-        if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
-            return False
+        log_interval_metadata: Dict[str, str] = {
+            "category": "checkpointing",
+            "active_phase": str(state.active_phase),
+            "hook": hook,
+            "epoch": str(_get_epoch(state, unit)),
+            "step": str(
+                _get_step_phase_mapping(state, unit).get(
+                    state.active_phase.into_phase(), 0
+                )
+            ),
+        }

-        if hook == "on_train_end":
-            # 2.1) Make sure that last checkpoint does not already exist
-            if self._checkpoint_manager.does_checkpoint_exist(
-                checkpoint_path, self._process_group
+        with log_interval(
+            "_generate_checkpoint_and_upkeep", metadata=log_interval_metadata
+        ):
+            # 1) generate checkpoint name
+            epoch = _get_epoch(state, unit)
+            step_mapping = _get_step_phase_mapping(state, unit)
+
+            # 1.1) append metric data only if best_checkpoint_config is defined
+            metric_data: Optional[MetricData] = None
+            if self._best_checkpoint_config and (
+                metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
             ):
-                rank_zero_warn(
-                    "Final checkpoint already exists, skipping.", logger=logger
+                metric_data = MetricData(
+                    name=none_throws(self._best_checkpoint_config).monitored_metric,
+                    value=metric_value,
                 )
+
+            checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
+                epoch,
+                step_mapping,
+                metric_data,
+                process_group=self._process_group,
+            )
+
+            # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
+            # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
+            if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
                 return False

-            # 2.2) If doing fit without eval checkpointing, only consider training progress when
-            # checking if last checkpoint exists.
-            if (
-                state.entry_point == EntryPoint.FIT
-                and self._save_every_n_eval_epochs is None
-                and self._checkpoint_manager._ckpt_paths
-                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
-                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            if hook == "on_train_end":
+                # 2.1) Make sure that last checkpoint does not already exist
+                if self._checkpoint_manager.does_checkpoint_exist(
+                    checkpoint_path, self._process_group
+                ):
+                    rank_zero_warn(
+                        "Final checkpoint already exists, skipping.", logger=logger
+                    )
+                    return False
+
+                # 2.2) If doing fit without eval checkpointing, only consider training progress when
+                # checking if last checkpoint exists.
+                if (
+                    state.entry_point == EntryPoint.FIT
+                    and self._save_every_n_eval_epochs is None
+                    and self._checkpoint_manager._ckpt_paths
+                    and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                    == cast(TTrainUnit, unit).train_progress.num_steps_completed
+                ):
+                    rank_zero_info(
+                        "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                        logger=logger,
+                    )
+                    return False
+
+            # 3) try to save checkpoint
+            if not self._checkpoint_impl(
+                state, unit, checkpoint_id=checkpoint_path.path, hook=hook
             ):
-                rank_zero_info(
-                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
-                    logger=logger,
-                )
                 return False

-        # 3) try to save checkpoint
-        if not self._checkpoint_impl(
-            state, unit, checkpoint_id=checkpoint_path.path, hook=hook
-        ):
-            return False
+            # 4) track checkpoint and clean up surplus if needed
+            self._checkpoint_manager.append_checkpoint(checkpoint_path)

-        # 4) track checkpoint and clean up surplus if needed
-        self._checkpoint_manager.append_checkpoint(checkpoint_path)
+            # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
+            unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)

-        # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
-        unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
-
-        return True
+            return True

     def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
         """
