From eb0e9c3acf1d7c8f11e0e40743f993d82fd24dea Mon Sep 17 00:00:00 2001 From: Desroziers Date: Thu, 24 Sep 2020 14:04:15 +0200 Subject: [PATCH 1/4] improve r2 for ddp --- ignite/contrib/metrics/regression/r2_score.py | 34 +++++-- .../metrics/regression/test_r2_score.py | 96 +++++++++++++++++++ 2 files changed, 121 insertions(+), 9 deletions(-) diff --git a/ignite/contrib/metrics/regression/r2_score.py b/ignite/contrib/metrics/regression/r2_score.py index bc5fa47f33a2..7f65cb015a08 100644 --- a/ignite/contrib/metrics/regression/r2_score.py +++ b/ignite/contrib/metrics/regression/r2_score.py @@ -1,7 +1,10 @@ +from typing import Callable, Union + import torch from ignite.contrib.metrics.regression._base import _BaseRegression from ignite.exceptions import NotComputableError +from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce class R2Score(_BaseRegression): @@ -18,21 +21,34 @@ class R2Score(_BaseRegression): - `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`. 
""" + def __init__( + self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu") + ): + self._num_examples = None + self._sum_of_errors = None + self._y_sq_sum = None + self._y_sum = None + super(R2Score, self).__init__(output_transform, device) + + @reinit__is_reduced def reset(self): - self._num_examples = 0 - self._sum_of_errors = 0 - self._y_sq_sum = 0 - self._y_sum = 0 + self._num_examples = torch.tensor(0, device=self._device) + self._sum_of_errors = torch.tensor(0.0, device=self._device) + self._y_sq_sum = torch.tensor(0.0, device=self._device) + self._y_sum = torch.tensor(0.0, device=self._device) def _update(self, output): y_pred, y = output self._num_examples += y.shape[0] - self._sum_of_errors += torch.sum(torch.pow(y_pred - y, 2)).item() + self._sum_of_errors += torch.sum(torch.pow(y_pred - y, 2)).to(self._device) - self._y_sum += torch.sum(y).item() - self._y_sq_sum += torch.sum(torch.pow(y, 2)).item() + self._y_sum += torch.sum(y).to(self._device) + self._y_sq_sum += torch.sum(torch.pow(y, 2)).to(self._device) + @sync_all_reduce("_num_examples", "_sum_of_errors", "_y_sq_sum", "_y_sum") def compute(self): - if self._num_examples == 0: + if self._num_examples.item() == 0: raise NotComputableError("R2Score must have at least one example before it can be computed.") - return 1 - self._sum_of_errors / (self._y_sq_sum - (self._y_sum ** 2) / self._num_examples) + return 1 - self._sum_of_errors.item() / ( + self._y_sq_sum.item() - (self._y_sum.item() ** 2) / self._num_examples.item() + ) diff --git a/tests/ignite/contrib/metrics/regression/test_r2_score.py b/tests/ignite/contrib/metrics/regression/test_r2_score.py index 9f3e48bdecca..64642f0e204f 100644 --- a/tests/ignite/contrib/metrics/regression/test_r2_score.py +++ b/tests/ignite/contrib/metrics/regression/test_r2_score.py @@ -1,8 +1,12 @@ +import os + import numpy as np import pytest import torch from sklearn.metrics import r2_score +import ignite.distributed 
as idist + from ignite.contrib.metrics.regression import R2Score from ignite.engine import Engine @@ -86,3 +90,95 @@ def update_fn(engine, batch): r_squared = engine.run(data, max_epochs=1).metrics["r2_score"] assert r2_score(np_y, np_y_pred) == pytest.approx(r_squared) + + +def _test_distrib_compute(device): + rank = idist.get_rank() + + def _test(metric_device): + metric_device = torch.device(metric_device) + m = R2Score(device=metric_device) + torch.manual_seed(10 + rank) + + y_pred = torch.randint(0, 10, size=(10,), device=device).float() + y = torch.randint(0, 10, size=(10,), device=device).float() + + m.update((y_pred, y)) + + # gather y_pred, y + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + + np_y_pred = y_pred.cpu().numpy() + np_y = y.cpu().numpy() + res = m.compute() + assert r2_score(np_y, np_y_pred) == pytest.approx(res) + + for _ in range(3): + _test("cpu") + if device.type != "xla": + _test(idist.device()) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU") +def test_distrib_gpu(distributed_context_single_node_nccl): + device = torch.device("cuda:{}".format(distributed_context_single_node_nccl["local_rank"])) + _test_distrib_compute(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +def test_distrib_cpu(distributed_context_single_node_gloo): + + device = torch.device("cpu") + _test_distrib_compute(device) + + +@pytest.mark.distributed +@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support") +@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc") +def test_distrib_hvd(gloo_hvd_executor): + + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") + nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() 
+ + gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): + device = torch.device("cpu") + _test_distrib_compute(device) + + +@pytest.mark.multinode_distributed +@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support") +@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed") +def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): + device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"])) + _test_distrib_compute(device) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_single_device_xla(): + device = idist.device() + _test_distrib_compute(device) + + +def _test_distrib_xla_nprocs(index): + device = idist.device() + _test_distrib_compute(device) + + +@pytest.mark.tpu +@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars") +@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package") +def test_distrib_xla_nprocs(xmp_executor): + n = int(os.environ["NUM_TPU_WORKERS"]) + xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n) From 29f7aff38ce44ff60f333f18d1af72385c13af98 Mon Sep 17 00:00:00 2001 From: AutoPEP8 <> Date: Thu, 24 Sep 2020 12:08:33 +0000 Subject: [PATCH 2/4] autopep8 fix --- tests/ignite/contrib/metrics/regression/test_r2_score.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/ignite/contrib/metrics/regression/test_r2_score.py 
b/tests/ignite/contrib/metrics/regression/test_r2_score.py index 64642f0e204f..3a6ecebe25ee 100644 --- a/tests/ignite/contrib/metrics/regression/test_r2_score.py +++ b/tests/ignite/contrib/metrics/regression/test_r2_score.py @@ -6,7 +6,6 @@ from sklearn.metrics import r2_score import ignite.distributed as idist - from ignite.contrib.metrics.regression import R2Score from ignite.engine import Engine From 44f000eea96829e31913e7d022b90327aec432a6 Mon Sep 17 00:00:00 2001 From: Desroziers Date: Mon, 28 Sep 2020 08:49:00 +0200 Subject: [PATCH 3/4] _num_examples type is scalar --- ignite/contrib/metrics/regression/r2_score.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ignite/contrib/metrics/regression/r2_score.py b/ignite/contrib/metrics/regression/r2_score.py index 7f65cb015a08..9d1b176b2322 100644 --- a/ignite/contrib/metrics/regression/r2_score.py +++ b/ignite/contrib/metrics/regression/r2_score.py @@ -32,7 +32,7 @@ def __init__( @reinit__is_reduced def reset(self): - self._num_examples = torch.tensor(0, device=self._device) + self._num_examples = 0 self._sum_of_errors = torch.tensor(0.0, device=self._device) self._y_sq_sum = torch.tensor(0.0, device=self._device) self._y_sum = torch.tensor(0.0, device=self._device) @@ -47,8 +47,8 @@ def _update(self, output): @sync_all_reduce("_num_examples", "_sum_of_errors", "_y_sq_sum", "_y_sum") def compute(self): - if self._num_examples.item() == 0: + if self._num_examples == 0: raise NotComputableError("R2Score must have at least one example before it can be computed.") return 1 - self._sum_of_errors.item() / ( - self._y_sq_sum.item() - (self._y_sum.item() ** 2) / self._num_examples.item() + self._y_sq_sum.item() - (self._y_sum.item() ** 2) / self._num_examples ) From f5edbbee34668903e5613388172fcaa302967be5 Mon Sep 17 00:00:00 2001 From: AutoPEP8 <> Date: Mon, 28 Sep 2020 06:50:43 +0000 Subject: [PATCH 4/4] autopep8 fix --- ignite/contrib/metrics/regression/r2_score.py | 4 +--- 1 file changed, 1 
insertion(+), 3 deletions(-) diff --git a/ignite/contrib/metrics/regression/r2_score.py b/ignite/contrib/metrics/regression/r2_score.py index 9d1b176b2322..1c9fd9a72d08 100644 --- a/ignite/contrib/metrics/regression/r2_score.py +++ b/ignite/contrib/metrics/regression/r2_score.py @@ -49,6 +49,4 @@ def _update(self, output): def compute(self): if self._num_examples == 0: raise NotComputableError("R2Score must have at least one example before it can be computed.") - return 1 - self._sum_of_errors.item() / ( - self._y_sq_sum.item() - (self._y_sum.item() ** 2) / self._num_examples - ) + return 1 - self._sum_of_errors.item() / (self._y_sq_sum.item() - (self._y_sum.item() ** 2) / self._num_examples)