
Commit d44ef31

[train v2][doc] Update user guides for metrics, checkpoints, results, and experiment tracking (#51204)
Updates a few user guides, mostly around the reporting of free-floating metrics, which Ray Train no longer persists. Ray Train only keeps metrics that are attached to reported checkpoints.

Signed-off-by: Justin Yu <[email protected]>
1 parent 2a1add3 commit d44ef31
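
The pattern this change standardizes on, attaching metrics to a reported checkpoint, looks roughly like the following sketch. This is a minimal illustration rather than code from the diff; the placeholder model, metric values, and "model.pt" file name are assumptions.

import os
import tempfile

import torch
import ray.train
from ray.train.torch import TorchTrainer


def train_func(config):
    model = torch.nn.Linear(4, 1)  # placeholder model for illustration
    for epoch in range(2):
        with tempfile.TemporaryDirectory() as tmpdir:
            # Save the training state to a local temporary directory.
            torch.save(model.state_dict(), os.path.join(tmpdir, "model.pt"))
            # Metrics attached to a reported checkpoint are kept by Ray Train;
            # metrics reported without a checkpoint are no longer persisted.
            ray.train.report(
                {"epoch": epoch, "loss": 0.0},
                checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
            )


trainer = TorchTrainer(
    train_func, scaling_config=ray.train.ScalingConfig(num_workers=2)
)
result = trainer.fit()
print(result.metrics)  # metrics attached to the latest reported checkpoint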

File tree

8 files changed, +138 -134 lines changed


doc/source/train/doc_code/checkpoints.py

Lines changed: 5 additions & 16 deletions
@@ -155,14 +155,6 @@ def train_func(config):
     run_config=train.RunConfig(failure_config=train.FailureConfig(max_failures=1)),
 )
 result = trainer.fit()
-
-# Seed a training run with a checkpoint using `resume_from_checkpoint`
-trainer = TorchTrainer(
-    train_func,
-    train_loop_config={"num_epochs": 5},
-    scaling_config=ScalingConfig(num_workers=2),
-    resume_from_checkpoint=result.checkpoint,
-)
 # __pytorch_restore_end__

 # __checkpoint_from_single_worker_start__
@@ -249,7 +241,7 @@ def on_train_epoch_end(self, trainer, pl_module):
         should_checkpoint = trainer.current_epoch % 3 == 0

         with TemporaryDirectory() as tmpdir:
-            # Fetch metrics
+            # Fetch metrics from `self.log(..)` in the LightningModule
             metrics = trainer.callback_metrics
             metrics = {k: v.item() for k, v in metrics.items()}

@@ -289,21 +281,18 @@ def train_func():
     checkpoint = train.get_checkpoint()
     if checkpoint:
         with checkpoint.as_directory() as ckpt_dir:
-            ckpt_path = os.path.join(ckpt_dir, "checkpoint.ckpt")
+            ckpt_path = os.path.join(ckpt_dir, RayTrainReportCallback.CHECKPOINT_NAME)
             trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_path)
     else:
         trainer.fit(model, datamodule=datamodule)


-# Build a Ray Train Checkpoint
-# Suppose we have a Lightning checkpoint at `s3://bucket/ckpt_dir/checkpoint.ckpt`
-checkpoint = Checkpoint("s3://bucket/ckpt_dir")
-
-# Resume training from checkpoint file
 ray_trainer = TorchTrainer(
     train_func,
     scaling_config=train.ScalingConfig(num_workers=2),
-    resume_from_checkpoint=checkpoint,
+    run_config=train.RunConfig(
+        checkpoint_config=train.CheckpointConfig(num_to_keep=2),
+    ),
 )
 # __lightning_restore_example_end__


doc/source/train/doc_code/key_concepts.py

Lines changed: 4 additions & 3 deletions
@@ -122,11 +122,12 @@ def train_fn(config):
 # __result_path_end__


+# TODO(justinvyu): Result.from_path is not supported in Train V2 yet.
 # __result_restore_start__
-from ray.train import Result
+# from ray.train import Result

-restored_result = Result.from_path(result_path)
-print("Restored loss", result.metrics["loss"])
+# restored_result = Result.from_path(result_path)
+# print("Restored loss", result.metrics["loss"])
 # __result_restore_end__


doc/source/train/doc_code/torchmetrics_example.py renamed to doc/source/train/doc_code/metric_logging.py

Lines changed: 61 additions & 8 deletions
@@ -1,11 +1,19 @@
 # flake8: noqa
 # isort: skip_file

-# __start__
+import os
+
+os.environ["RAY_TRAIN_V2_ENABLED"] = "1"
+
+
+# __torchmetrics_start__

 # First, pip install torchmetrics
 # This code is tested with torchmetrics==0.7.3 and torch==1.12.1

+import os
+import tempfile
+
 import ray.train.torch
 from ray import train
 from ray.train import ScalingConfig
@@ -62,13 +70,19 @@ def train_func(config):
         mape_collected = mape.compute().item()
         mean_valid_loss_collected = mean_valid_loss.compute().item()

-        train.report(
-            {
-                "mape_collected": mape_collected,
-                "valid_loss": valid_loss,
-                "mean_valid_loss_collected": mean_valid_loss_collected,
-            }
-        )
+        with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
+            torch.save(
+                model.state_dict(), os.path.join(temp_checkpoint_dir, "model.pt")
+            )
+
+            train.report(
+                {
+                    "mape_collected": mape_collected,
+                    "valid_loss": valid_loss,
+                    "mean_valid_loss_collected": mean_valid_loss_collected,
+                },
+                checkpoint=train.Checkpoint.from_directory(temp_checkpoint_dir),
+            )

         # reset for next epoch
         mape.reset()
@@ -83,3 +97,42 @@ def train_func(config):
 result = trainer.fit()
 print(result.metrics["valid_loss"], result.metrics["mean_valid_loss_collected"])
 # 0.5109779238700867 0.5512474775314331
+
+# __torchmetrics_end__
+
+# __report_callback_start__
+import os
+
+assert os.environ["RAY_TRAIN_V2_ENABLED"] == "1"
+
+from typing import Any, Dict, List, Optional
+
+import ray.train
+import ray.train.torch
+
+
+def train_fn_per_worker(config):
+    # Free-floating metrics can be accessed from the callback below.
+    ray.train.report({"rank": ray.train.get_context().get_world_rank()})
+
+
+class CustomMetricsCallback(ray.train.UserCallback):
+    def after_report(
+        self,
+        run_context,
+        metrics: List[Dict[str, Any]],
+        checkpoint: Optional[ray.train.Checkpoint],
+    ):
+        rank_0_metrics = metrics[0]
+        print(rank_0_metrics)
+        # Ex: Write metrics to a file...
+
+
+trainer = ray.train.torch.TorchTrainer(
+    train_fn_per_worker,
+    scaling_config=ray.train.ScalingConfig(num_workers=2),
+    run_config=ray.train.RunConfig(callbacks=[CustomMetricsCallback()]),
+)
+trainer.fit()
+
+# __report_callback_end__

doc/source/train/user-guides/checkpoints.rst

Lines changed: 9 additions & 14 deletions
@@ -7,13 +7,10 @@ Ray Train provides a way to snapshot training progress with :class:`Checkpoints

 This is useful for:

-1. **Storing the best-performing model weights:** Save your model to persistent storage, and use it for downstream serving/inference.
-2. **Fault tolerance:** Handle node failures in a long-running training job on a cluster of pre-emptible machines/pods.
-3. **Distributed checkpointing:** When doing *model-parallel training*, Ray Train checkpointing provides an easy way to
-   :ref:`upload model shards from each worker in parallel <train-distributed-checkpointing>`,
-   without needing to gather the full model to a single node.
-4. **Integration with Ray Tune:** Checkpoint saving and loading is required by certain :ref:`Ray Tune schedulers <tune-schedulers>`.
-
+1. **Storing the best-performing model weights:** Save your model to persistent storage, and use it for downstream serving or inference.
+2. **Fault tolerance:** Handle worker process and node failures in a long-running training job and leverage pre-emptible machines.
+3. **Distributed checkpointing:** Ray Train checkpointing can be used to
+   :ref:`upload model shards from multiple workers in parallel. <train-distributed-checkpointing>`

 .. _train-dl-saving-checkpoints:

@@ -69,8 +66,8 @@ Then, the local temporary directory can be safely cleaned up to free up disk spa
     :start-after: __checkpoint_from_single_worker_start__
     :end-before: __checkpoint_from_single_worker_end__

-If using parallel training strategies such as DeepSpeed Zero-3 and FSDP, where
-each worker only has a shard of the full-model, you should save and report a checkpoint
+If using parallel training strategies such as DeepSpeed Zero and FSDP, where
+each worker only has a shard of the full training state, you can save and report a checkpoint
 from each worker. See :ref:`train-distributed-checkpointing` for an example.


@@ -310,12 +307,10 @@ training state from a :class:`~ray.train.Checkpoint`.
 The :class:`Checkpoint <ray.train.Checkpoint>` to restore from can be accessed in the
 training function with :func:`ray.train.get_checkpoint <ray.train.get_checkpoint>`.

-The checkpoint returned by :func:`ray.train.get_checkpoint <ray.train.get_checkpoint>` is populated in two ways:
-
-1. It can be auto-populated as the latest reported checkpoint, e.g. during :ref:`automatic failure recovery <train-fault-tolerance>` or :ref:`on manual restoration <train-restore-guide>`.
-2. It can be manually populated by passing a checkpoint to the ``resume_from_checkpoint`` argument of a Ray :class:`Trainer <ray.train.trainer.BaseTrainer>`.
-   This is useful for initializing a new training run with a previous run's checkpoint.
+The checkpoint returned by :func:`ray.train.get_checkpoint <ray.train.get_checkpoint>` is populated
+as the latest reported checkpoint during :ref:`automatic failure recovery <train-fault-tolerance>`.

+See :ref:`train-fault-tolerance` for more details on restoration and fault tolerance.

 .. tab-set::

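The distributed checkpointing mentioned in the changes above refers to each worker saving and reporting its own checkpoint shard. A minimal sketch of that pattern follows, not taken from this diff; the rank-specific shard file name is chosen for illustration:

import os
import tempfile

import torch
import ray.train


def train_func(config):
    model = torch.nn.Linear(8, 1)  # stand-in for a sharded model state
    rank = ray.train.get_context().get_world_rank()
    with tempfile.TemporaryDirectory() as tmpdir:
        # Each worker saves only its own shard under a rank-specific name...
        torch.save(model.state_dict(), os.path.join(tmpdir, f"model-rank={rank}.pt"))
        # ...and every worker reports a checkpoint, so shards upload in parallel.
        ray.train.report(
            {"epoch": 0},
            checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
        )
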
doc/source/train/user-guides/experiment-tracking.rst

Lines changed: 0 additions & 26 deletions
@@ -4,11 +4,6 @@
 Experiment Tracking
 ===================

-.. note::
-    This guide is relevant for all trainers in which you define a custom training loop.
-    This includes :class:`TorchTrainer <ray.train.torch.TorchTrainer>` and
-    :class:`TensorflowTrainer <ray.train.tensorflow.TensorflowTrainer>`.
-
 Most experiment tracking libraries work out-of-the-box with Ray Train.
 This guide provides instructions on how to set up the code so that your favorite experiment tracking libraries
 can work for distributed training with Ray Train. The end of the guide has common errors to aid in debugging
@@ -253,27 +248,6 @@ Refer to the tracking libraries' documentation for semantics.

 When performing **fault-tolerant training** with auto-restoration, use a
 consistent ID to configure all tracking runs that logically belong to the same training run.
-One way to acquire an unique ID is with the following method:
-:meth:`ray.train.get_context().get_trial_id() <ray.train.context.TrainContext.get_trial_id>`.
-
-.. testcode::
-    :skipif: True
-
-    import ray
-    from ray.train import ScalingConfig, RunConfig, FailureConfig
-    from ray.train.torch import TorchTrainer
-
-    def train_func():
-        if ray.train.get_context().get_world_rank() == 0:
-            wandb.init(id=ray.train.get_context().get_trial_id())
-        ...
-
-    trainer = TorchTrainer(
-        train_func,
-        run_config=RunConfig(failure_config=FailureConfig(max_failures=3))
-    )
-
-    trainer.fit()


 Step 3: Log metrics

doc/source/train/user-guides/hyperparameter-optimization.rst

Lines changed: 2 additions & 0 deletions
@@ -140,6 +140,8 @@ Fault tolerance on the Ray Train side is configured and handled separately. See
     :end-before: __fault_tolerance_end__


+.. _train-with-tune-callbacks:
+
 Advanced: Using Ray Tune callbacks
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

doc/source/train/user-guides/monitoring-logging.rst

Lines changed: 40 additions & 53 deletions
@@ -3,60 +3,15 @@
 Monitoring and Logging Metrics
 ==============================

-Ray Train provides an API for reporting intermediate
-results and checkpoints from the training function (run on distributed workers) up to the
-``Trainer`` (where your python script is executed) by calling ``train.report(metrics)``.
-The results will be collected from the distributed workers and passed to the driver to
-be logged and displayed.
+Ray Train provides an API for attaching metrics to :ref:`checkpoints <train-checkpointing>` from the training function by calling :func:`ray.train.report(metrics, checkpoint) <ray.train.report>`.
+The results will be collected from the distributed workers and passed to the Ray Train driver process for book-keeping.

-.. warning::
+The primary use-case for reporting is for metrics (accuracy, loss, etc.) at the end of each training epoch. See :ref:`train-dl-saving-checkpoints` for usage examples.

-    Only the results from rank 0 worker will be used. However, in order to ensure
-    consistency, ``train.report()`` has to be called on each worker. If you
-    want to aggregate results from multiple workers, see :ref:`train-aggregating-results`.
+Only the result reported by the rank 0 worker will be attached to the checkpoint.
+However, in order to ensure consistency, ``train.report()`` acts as a barrier and must be called on each worker.
+To aggregate results from multiple workers, see :ref:`train-aggregating-results`.

-The primary use-case for reporting is for metrics (accuracy, loss, etc.) at
-the end of each training epoch.
-
-.. tab-set::
-
-    .. tab-item:: PyTorch
-
-        .. testcode::
-
-            from ray import train
-
-            def train_func():
-                ...
-                for i in range(num_epochs):
-                    result = model.train(...)
-                    train.report({"result": result})
-
-    .. tab-item:: PyTorch Lightning
-
-        In PyTorch Lightning, we use a callback to call ``train.report()``.
-
-        .. testcode::
-            :skipif: True
-
-            from ray import train
-            import pytorch_lightning as pl
-            from pytorch_lightning.callbacks import Callback
-
-            class MyRayTrainReportCallback(Callback):
-                def on_train_epoch_end(self, trainer, pl_module):
-                    metrics = trainer.callback_metrics
-                    metrics = {k: v.item() for k, v in metrics.items()}
-
-                    train.report(metrics=metrics)
-
-            def train_func_per_worker():
-                ...
-                trainer = pl.Trainer(
-                    # ...
-                    callbacks=[MyRayTrainReportCallback()]
-                )
-                trainer.fit()

 .. _train-aggregating-results:

@@ -77,6 +32,38 @@ metrics from multiple workers.

 Here is an example of reporting both the aggregated R2 score and mean train and validation loss from all workers.

-.. literalinclude:: ../doc_code/torchmetrics_example.py
+.. literalinclude:: ../doc_code/metric_logging.py
     :language: python
-    :start-after: __start__
+    :start-after: __torchmetrics_start__
+    :end-before: __torchmetrics_end__
+
+
+.. _train-metric-only-reporting-deprecation:
+
+(Deprecated) Reporting free-floating metrics
+--------------------------------------------
+
+Reporting metrics with ``ray.train.report(metrics, checkpoint=None)`` from every worker writes the metrics to a Ray Tune log file (``progress.csv``, ``result.json``)
+and is accessible via the ``Result.metrics_dataframe`` on the :class:`~ray.train.Result` returned by ``trainer.fit()``.
+
+As of Ray 2.43, this behavior is deprecated and will not be supported in Ray Train V2,
+which is an overhaul of Ray Train's implementation and select APIs.
+
+Ray Train V2 only keeps a slim set of experiment tracking features that are necessary for fault tolerance, so it does not support reporting free-floating metrics that are not attached to checkpoints.
+The recommendation for metric tracking is to report metrics directly from the workers to experiment tracking tools such as MLFlow and WandB.
+See :ref:`train-experiment-tracking-native` for examples.
+
+In Ray Train V2, reporting only metrics from all workers is a no-op. However, it is still possible to access the results reported by all workers to implement custom metric-handling logic.
+
+.. literalinclude:: ../doc_code/metric_logging.py
+    :language: python
+    :start-after: __report_callback_start__
+    :end-before: __report_callback_end__
+
+
+To use Ray Tune :class:`Callbacks <ray.tune.Callback>` that depend on free-floating metrics reported by workers, :ref:`run Ray Train as a single Ray Tune trial. <train-with-tune-callbacks>`
+
+See the following resources for more information:
+
+* `Train V2 REP <https://github.com/ray-project/enhancements/blob/main/reps/2024-10-18-train-tune-api-revamp/2024-10-18-train-tune-api-revamp.md>`_: Technical details about the API changes in Train V2
+* `Train V2 Migration Guide <https://github.com/ray-project/ray/issues/49454>`_: Full migration guide for Train V2
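
For the direct-to-tracking-tool approach that the new text recommends, a minimal sketch looks like the following, assuming Weights & Biases; the project name and metric keys are illustrative:

import ray.train
import wandb


def train_func(config):
    # Log metrics straight from the rank 0 worker to the tracking tool,
    # instead of relying on Ray Train to persist free-floating metrics.
    if ray.train.get_context().get_world_rank() == 0:
        wandb.init(project="my-project")
    for epoch in range(config.get("num_epochs", 1)):
        loss = 0.0  # placeholder for the real training loss
        if ray.train.get_context().get_world_rank() == 0:
            wandb.log({"epoch": epoch, "loss": loss})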

0 commit comments
