Skip to content

Commit 722b387

Browse files
diego-urgellfacebook-github-bot
authored andcommitted
Save checkpoint in on_predict_end hook (#973)
Summary: Pull Request resolved: #973 Reviewed By: galrotem Differential Revision: D69865408 fbshipit-source-id: 82ba22fcfffcc2689089ae49170e95e94fcc361c
1 parent 45e1138 commit 722b387

File tree

4 files changed

+18
-3
lines changed

4 files changed

+18
-3
lines changed

tests/framework/callbacks/test_base_checkpointer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,11 @@ def _checkpoint_impl_side_effect(
339339
expected_ckpts = [
340340
f"{temp_dir}/epoch_0_predict_step_{i}" for i in range(1, 11)
341341
]
342+
343+
expected_ckpts.append(
344+
f"{temp_dir}/epoch_1_predict_step_10"
345+
) # We always expect checkpoint on predict end
346+
342347
self.assertEqual(ckpt_container, expected_ckpts)
343348

344349
@unittest.mock.patch("sys.stdout", new_callable=io.StringIO)

tests/framework/callbacks/test_dcp_saver.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -536,19 +536,25 @@ def test_save_restore_predict(self) -> None:
536536
expected_ckpts = [
537537
"epoch_0_predict_step_2",
538538
"epoch_0_predict_step_4",
539+
"epoch_1_predict_step_5",
539540
]
540541

541542
self.assertCountEqual(generated_ckpts, expected_ckpts)
542543

543-
ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
544-
self.assertEqual(ckpt_path, os.path.join(temp_dir, expected_ckpts[-1]))
544+
latest_ckpt_path = none_throws(get_latest_checkpoint_path(temp_dir))
545+
self.assertEqual(
546+
latest_ckpt_path, os.path.join(temp_dir, expected_ckpts[-1])
547+
)
545548

546549
expected_keys = [
547550
"predict_progress",
548551
"predict_dataloader",
549552
"output_mean",
550553
]
551554

555+
# Check keys on a checkpoint other than the latest since it won't have dataloader state
556+
ckpt_path = f"{temp_dir}/{expected_ckpts[0]}"
557+
552558
storage_reader = FsspecReader(ckpt_path)
553559
metadata = storage_reader.read_metadata()
554560
self.assertCountEqual(

torchtnt/framework/callbacks/base_checkpointer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,9 @@ def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
378378

379379
self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_step_end")
380380

381+
def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
382+
self._generate_checkpoint_and_upkeep(state, unit, hook="on_predict_end")
383+
381384
def _disable_ckpt_optimality_tracking(self) -> None:
382385
"""
383386
Disables checkpoint optimality tracking. This means that best_checkpoint and keep_last_n_checkpoints

torchtnt/framework/callbacks/dcp_saver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ def _checkpoint_impl(
169169
"on_eval_epoch_end",
170170
"on_eval_step_end",
171171
"on_predict_step_end",
172+
"on_predict_end",
172173
]:
173174
raise RuntimeError(f"Unexpected hook encountered '{hook}'")
174175

@@ -178,7 +179,7 @@ def _checkpoint_impl(
178179
intra_epoch = "step_end" in hook or (
179180
"on_eval_epoch_end" == hook and state.entry_point == EntryPoint.FIT
180181
)
181-
curr_snapshot_wait = hook == "on_train_end"
182+
curr_snapshot_wait = hook in ("on_train_end", "on_predict_end")
182183

183184
if planner is None:
184185
planner = DefaultSavePlanner()

0 commit comments

Comments
 (0)