Skip to content

Improve R2Score metric for DDP #1318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Sep 28, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions ignite/contrib/metrics/regression/r2_score.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Callable, Union

import torch

from ignite.contrib.metrics.regression._base import _BaseRegression
from ignite.exceptions import NotComputableError
from ignite.metrics.metric import reinit__is_reduced, sync_all_reduce


class R2Score(_BaseRegression):
Expand All @@ -18,21 +21,34 @@ class R2Score(_BaseRegression):
- `y` and `y_pred` must be of same shape `(N, )` or `(N, 1)` and of type `float32`.
"""

def __init__(
self, output_transform: Callable = lambda x: x, device: Union[str, torch.device] = torch.device("cpu")
):
self._num_examples = None
self._sum_of_errors = None
self._y_sq_sum = None
self._y_sum = None
super(R2Score, self).__init__(output_transform, device)

@reinit__is_reduced
def reset(self):
self._num_examples = 0
self._sum_of_errors = 0
self._y_sq_sum = 0
self._y_sum = 0
self._num_examples = torch.tensor(0, device=self._device)
self._sum_of_errors = torch.tensor(0.0, device=self._device)
self._y_sq_sum = torch.tensor(0.0, device=self._device)
self._y_sum = torch.tensor(0.0, device=self._device)

def _update(self, output):
y_pred, y = output
self._num_examples += y.shape[0]
self._sum_of_errors += torch.sum(torch.pow(y_pred - y, 2)).item()
self._sum_of_errors += torch.sum(torch.pow(y_pred - y, 2)).to(self._device)

self._y_sum += torch.sum(y).item()
self._y_sq_sum += torch.sum(torch.pow(y, 2)).item()
self._y_sum += torch.sum(y).to(self._device)
self._y_sq_sum += torch.sum(torch.pow(y, 2)).to(self._device)

@sync_all_reduce("_num_examples", "_sum_of_errors", "_y_sq_sum", "_y_sum")
def compute(self):
if self._num_examples == 0:
if self._num_examples.item() == 0:
raise NotComputableError("R2Score must have at least one example before it can be computed.")
return 1 - self._sum_of_errors / (self._y_sq_sum - (self._y_sum ** 2) / self._num_examples)
return 1 - self._sum_of_errors.item() / (
self._y_sq_sum.item() - (self._y_sum.item() ** 2) / self._num_examples.item()
)
95 changes: 95 additions & 0 deletions tests/ignite/contrib/metrics/regression/test_r2_score.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os

import numpy as np
import pytest
import torch
from sklearn.metrics import r2_score

import ignite.distributed as idist
from ignite.contrib.metrics.regression import R2Score
from ignite.engine import Engine

Expand Down Expand Up @@ -86,3 +89,95 @@ def update_fn(engine, batch):
r_squared = engine.run(data, max_epochs=1).metrics["r2_score"]

assert r2_score(np_y, np_y_pred) == pytest.approx(r_squared)


def _test_distrib_compute(device):
rank = idist.get_rank()

def _test(metric_device):
metric_device = torch.device(metric_device)
m = R2Score(device=metric_device)
torch.manual_seed(10 + rank)

y_pred = torch.randint(0, 10, size=(10,), device=device).float()
y = torch.randint(0, 10, size=(10,), device=device).float()

m.update((y_pred, y))

# gather y_pred, y
y_pred = idist.all_gather(y_pred)
y = idist.all_gather(y)

np_y_pred = y_pred.cpu().numpy()
np_y = y.cpu().numpy()
res = m.compute()
assert r2_score(np_y, np_y_pred) == pytest.approx(res)

for _ in range(3):
_test("cpu")
if device.type != "xla":
_test(idist.device())


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_distrib_gpu(distributed_context_single_node_nccl):
device = torch.device("cuda:{}".format(distributed_context_single_node_nccl["local_rank"]))
_test_distrib_compute(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
def test_distrib_cpu(distributed_context_single_node_gloo):

device = torch.device("cpu")
_test_distrib_compute(device)


@pytest.mark.distributed
@pytest.mark.skipif(not idist.has_hvd_support, reason="Skip if no Horovod dist support")
@pytest.mark.skipif("WORLD_SIZE" in os.environ, reason="Skip if launched as multiproc")
def test_distrib_hvd(gloo_hvd_executor):

device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count()

gloo_hvd_executor(_test_distrib_compute, (device,), np=nproc, do_init=True)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_cpu(distributed_context_multi_node_gloo):
device = torch.device("cpu")
_test_distrib_compute(device)


@pytest.mark.multinode_distributed
@pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
def test_multinode_distrib_gpu(distributed_context_multi_node_nccl):
device = torch.device("cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]))
_test_distrib_compute(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" in os.environ, reason="Skip if NUM_TPU_WORKERS is in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_single_device_xla():
device = idist.device()
_test_distrib_compute(device)


def _test_distrib_xla_nprocs(index):
device = idist.device()
_test_distrib_compute(device)


@pytest.mark.tpu
@pytest.mark.skipif("NUM_TPU_WORKERS" not in os.environ, reason="Skip if no NUM_TPU_WORKERS in env vars")
@pytest.mark.skipif(not idist.has_xla_support, reason="Skip if no PyTorch XLA package")
def test_distrib_xla_nprocs(xmp_executor):
n = int(os.environ["NUM_TPU_WORKERS"])
xmp_executor(_test_distrib_xla_nprocs, args=(), nprocs=n)