@@ -10,7 +10,7 @@
 import logging
 import math
 from datetime import timedelta
-from typing import Any, cast, Iterable, List, Literal, Optional, Union
+from typing import Any, cast, Dict, Iterable, List, Literal, Optional, Union

 import fsspec

@@ -39,6 +39,7 @@
     Phase,
 )
 from torchtnt.utils.distributed import get_world_size, PGWrapper
+from torchtnt.utils.event_handlers import log_interval
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn

 logger: logging.Logger = logging.getLogger(__name__)
@@ -201,70 +202,85 @@ def _generate_checkpoint_and_upkeep(
         Returns:
             True if checkpoint was successfully saved. False otherwise.
         """
-        # 1) generate checkpoint name
-        epoch = _get_epoch(state, unit)
-        step_mapping = _get_step_phase_mapping(state, unit)
-
-        # 1.1) append metric data only if best_checkpoint_config is defined
-        metric_data: Optional[MetricData] = None
-        if self._best_checkpoint_config and (
-            metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
-        ):
-            metric_data = MetricData(
-                name=none_throws(self._best_checkpoint_config).monitored_metric,
-                value=metric_value,
-            )
-
-        checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
-            epoch,
-            step_mapping,
-            metric_data,
-            process_group=self._process_group,
-        )
-
-        # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
-        # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
-        if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
-            return False
+        log_interval_metadata: Dict[str, str] = {
+            "category": "checkpointing",
+            "active_phase": str(state.active_phase),
+            "hook": hook,
+            "epoch": str(_get_epoch(state, unit)),
+            "step": str(
+                _get_step_phase_mapping(state, unit).get(
+                    state.active_phase.into_phase(), 0
+                )
+            ),
+        }

-        if hook == "on_train_end":
-            # 2.1) Make sure that last checkpoint does not already exist
-            if self._checkpoint_manager.does_checkpoint_exist(
-                checkpoint_path, self._process_group
+        with log_interval(
+            "_generate_checkpoint_and_upkeep", metadata=log_interval_metadata
+        ):
+            # 1) generate checkpoint name
+            epoch = _get_epoch(state, unit)
+            step_mapping = _get_step_phase_mapping(state, unit)
+
+            # 1.1) append metric data only if best_checkpoint_config is defined
+            metric_data: Optional[MetricData] = None
+            if self._best_checkpoint_config and (
+                metric_value := self._get_tracked_metric_value(cast(TTrainUnit, unit))
             ):
-                rank_zero_warn(
-                    "Final checkpoint already exists, skipping.", logger=logger
+                metric_data = MetricData(
+                    name=none_throws(self._best_checkpoint_config).monitored_metric,
+                    value=metric_value,
                 )
+
+            checkpoint_path = self._checkpoint_manager.generate_checkpoint_path(
+                epoch,
+                step_mapping,
+                metric_data,
+                process_group=self._process_group,
+            )
+
+            # 2) Determine if we should save checkpoint. This is a no-op for eval and predict entrypoints
+            # since neither best_checkpoint_config nor keep_last_n_checkpoints are supported.
+            if not self._checkpoint_manager.should_save_checkpoint(checkpoint_path):
                 return False

-            # 2.2) If doing fit without eval checkpointing, only consider training progress when
-            # checking if last checkpoint exists.
-            if (
-                state.entry_point == EntryPoint.FIT
-                and self._save_every_n_eval_epochs is None
-                and self._checkpoint_manager._ckpt_paths
-                and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
-                == cast(TTrainUnit, unit).train_progress.num_steps_completed
+            if hook == "on_train_end":
+                # 2.1) Make sure that last checkpoint does not already exist
+                if self._checkpoint_manager.does_checkpoint_exist(
+                    checkpoint_path, self._process_group
+                ):
+                    rank_zero_warn(
+                        "Final checkpoint already exists, skipping.", logger=logger
+                    )
+                    return False
+
+                # 2.2) If doing fit without eval checkpointing, only consider training progress when
+                # checking if last checkpoint exists.
+                if (
+                    state.entry_point == EntryPoint.FIT
+                    and self._save_every_n_eval_epochs is None
+                    and self._checkpoint_manager._ckpt_paths
+                    and self._checkpoint_manager._ckpt_paths[-1].step[Phase.TRAIN]
+                    == cast(TTrainUnit, unit).train_progress.num_steps_completed
+                ):
+                    rank_zero_info(
+                        "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
+                        logger=logger,
+                    )
+                    return False
+
+            # 3) try to save checkpoint
+            if not self._checkpoint_impl(
+                state, unit, checkpoint_id=checkpoint_path.path, hook=hook
            ):
-                rank_zero_info(
-                    "Omitting final checkpoint since train progress is unchanged, and eval checkpointing is not configured.",
-                    logger=logger,
-                )
                 return False

-        # 3) try to save checkpoint
-        if not self._checkpoint_impl(
-            state, unit, checkpoint_id=checkpoint_path.path, hook=hook
-        ):
-            return False
+            # 4) track checkpoint and clean up surplus if needed
+            self._checkpoint_manager.append_checkpoint(checkpoint_path)

-        # 4) track checkpoint and clean up surplus if needed
-        self._checkpoint_manager.append_checkpoint(checkpoint_path)
+            # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
+            unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)

-        # 5) invoke on_checkpoint_save callback on the unit since checkpoint was saved successfully
-        unit.on_checkpoint_save(state, checkpoint_id=checkpoint_path.path)
-
-        return True
+            return True

     def _get_tracked_metric_value(self, unit: TTrainUnit) -> Optional[float]:
         """
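The change above wraps the entire body of _generate_checkpoint_and_upkeep in log_interval, so every exit path (including the early return False branches) is recorded under a single named interval tagged with the checkpointing metadata. Below is a minimal sketch of that usage pattern, assuming only the call shape visible in the diff (log_interval(name, metadata=...) used as a context manager); the function name, metadata keys, and body are illustrative, not part of torchtnt's API.

from typing import Dict

from torchtnt.utils.event_handlers import log_interval


def save_artifact(epoch: int, step: int) -> bool:
    # String-valued metadata describing the work being timed; these keys are
    # hypothetical, chosen to mirror the dict built in the diff above.
    metadata: Dict[str, str] = {
        "category": "checkpointing",
        "epoch": str(epoch),
        "step": str(step),
    }
    # Everything inside the block is attributed to the named interval, and the
    # interval closes on any exit path, including early returns.
    with log_interval("save_artifact", metadata=metadata):
        if step == 0:
            return False  # early exit still ends the interval cleanly
        # ... perform the actual save here ...
        return True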
|