diff --git a/docs/source/metrics.rst b/docs/source/metrics.rst index bc2da2e12cba..47b7f9f3aab4 100644 --- a/docs/source/metrics.rst +++ b/docs/source/metrics.rst @@ -120,15 +120,15 @@ specific condition (e.g. ignore user-defined classes): class CustomAccuracy(Metric): - def __init__(self, ignored_class, output_transform=lambda x: x): + def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"): self.ignored_class = ignored_class self._num_correct = None self._num_examples = None - super(CustomAccuracy, self).__init__(output_transform=output_transform) + super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device) @reinit__is_reduced def reset(self): - self._num_correct = 0 + self._num_correct = torch.tensor(0, device=self._device) self._num_examples = 0 super(CustomAccuracy, self).reset() @@ -144,14 +144,16 @@ specific condition (e.g. ignore user-defined classes): indices = indices[mask] correct = torch.eq(indices, y).view(-1) - self._num_correct += torch.sum(correct).item() + # We must detach tensors before adding them to the internal variables. In this case, torch.eq is not + # differentiable, so the computation graph is detached implicitly + self._num_correct += torch.sum(correct).to(self._device) self._num_examples += correct.shape[0] @sync_all_reduce("_num_examples", "_num_correct") def compute(self): if self._num_examples == 0: raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.') - return self._num_correct / self._num_examples + return self._num_correct.item() / self._num_examples We imported necessary classes as :class:`~ignite.metrics.Metric`, :class:`~ignite.exceptions.NotComputableError` and @@ -159,6 +161,11 @@ decorators to adapt the metric for distributed setting. In ``reset`` method, we and ``_num_examples`` which are used to compute the custom metric. In ``updated`` method we define how to update the internal variables. And finally in ``compute`` method, we compute metric value. +Notice that ``_num_correct`` is a tensor, since in ``update`` we accumulate tensor values. ``_num_examples`` is a python +scalar since we accumulate normal integers. For differentiable metrics, you must detach the accumulated values before +adding them to the internal variables. Accuracy is not differentiable (specifically the ``torch.eq`` call), so it +is implicitly detached from the computation graph. + We can check this implementation in a simple case: .. code-block:: python diff --git a/ignite/metrics/accumulation.py b/ignite/metrics/accumulation.py index 0de2ee469cc7..5a000c38911e 100644 --- a/ignite/metrics/accumulation.py +++ b/ignite/metrics/accumulation.py @@ -31,7 +31,9 @@ class VariableAccumulation(Metric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ @@ -70,6 +72,9 @@ def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None: output = output.to(self._device) self.accumulator = self._op(self.accumulator, output) + if isinstance(self.accumulator, torch.Tensor): + self.accumulator = self.accumulator.to(self._device) + if hasattr(output, "shape"): self.num_examples += output.shape[0] if len(output.shape) > 1 else 1 else: @@ -114,8 +119,9 @@ class Average(VariableAccumulation): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. - + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ def __init__( @@ -160,7 +166,9 @@ class GeometricAverage(VariableAccumulation): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ diff --git a/ignite/metrics/accuracy.py b/ignite/metrics/accuracy.py index 79ba5e4c1d80..fe53db66f828 100644 --- a/ignite/metrics/accuracy.py +++ b/ignite/metrics/accuracy.py @@ -122,7 +122,9 @@ def thresholded_output_transform(output): form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. is_multilabel (bool, optional): flag to use in multilabel case. By default, False. - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ diff --git a/ignite/metrics/confusion_matrix.py b/ignite/metrics/confusion_matrix.py index 574e7e6a1105..3d179a458d62 100644 --- a/ignite/metrics/confusion_matrix.py +++ b/ignite/metrics/confusion_matrix.py @@ -30,7 +30,9 @@ class ConfusionMatrix(Metric): :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. Note: In case of the targets `y` in `(batch_size, ...)` format, target indices between 0 and `num_classes` only diff --git a/ignite/metrics/fbeta.py b/ignite/metrics/fbeta.py index 6af40b7234d0..eb6776a17eb8 100644 --- a/ignite/metrics/fbeta.py +++ b/ignite/metrics/fbeta.py @@ -28,7 +28,9 @@ def Fbeta( output_transform (callable, optional): a callable that is used to transform the :class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the form expected by the metric. It is used only if precision or recall are not provided. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. Returns: MetricsLambda, F-beta metric diff --git a/ignite/metrics/loss.py b/ignite/metrics/loss.py index b0e7d1955fd7..305ad7c2cb1e 100644 --- a/ignite/metrics/loss.py +++ b/ignite/metrics/loss.py @@ -26,7 +26,9 @@ class Loss(Metric): keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`. batch_size (callable): a callable taking a target tensor that returns the first dimension size (usually the batch size). - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. """ diff --git a/ignite/metrics/metric.py b/ignite/metrics/metric.py index 64b18169d375..42f3897625c1 100644 --- a/ignite/metrics/metric.py +++ b/ignite/metrics/metric.py @@ -122,7 +122,9 @@ class Metric(metaclass=ABCMeta): form expected by the metric. This can be useful if, for example, you have a multi-output model and you want to compute the metric with respect to one of the outputs. By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. - device (str of torch.device, optional): optional device specification for internal storage. + device (str or torch.device): specifies which device updates are accumulated on. Setting the + metric's device to be the same as your ``update`` arguments ensures the ``update`` method is + non-blocking. By default, CPU. """ diff --git a/ignite/metrics/precision.py b/ignite/metrics/precision.py index 7754c47f47cc..055f11c9c391 100644 --- a/ignite/metrics/precision.py +++ b/ignite/metrics/precision.py @@ -122,7 +122,9 @@ def thresholded_output_transform(output): in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case). is_multilabel (bool, optional) flag to use in multilabel case. By default, value is False. If True, average parameter should be True and the average is computed across samples, instead of classes. - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ diff --git a/ignite/metrics/recall.py b/ignite/metrics/recall.py index d446583f6bd6..a094557249c5 100644 --- a/ignite/metrics/recall.py +++ b/ignite/metrics/recall.py @@ -60,7 +60,9 @@ def thresholded_output_transform(output): in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case). is_multilabel (bool, optional) flag to use in multilabel case. By default, value is False. If True, average parameter should be True and the average is computed across samples, instead of classes. - device (str of torch.device, optional): unused argument. + device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's + device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By + default, CPU. """ diff --git a/ignite/metrics/running_average.py b/ignite/metrics/running_average.py index 2094e3866ada..0fa1216c7940 100644 --- a/ignite/metrics/running_average.py +++ b/ignite/metrics/running_average.py @@ -20,7 +20,11 @@ class RunningAverage(Metric): corresponds the output of process function. Otherwise it should be None. epoch_bound (boolean, optional): whether the running average should be reset after each epoch (defaults to True). - device (str of torch.device, optional): unused argument. + device (str or torch.device, optional): specifies which device updates are accumulated on. Should be + None when ``src`` is an instance of :class:`~ignite.metrics.Metric`, as the running average will + use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value + from the metric is a tensor. + Examples: @@ -63,6 +67,7 @@ def __init__( self.src = src self._get_src_value = self._get_metric_value self.iteration_completed = self._metric_iteration_completed + device = src._device else: if output_transform is None: raise ValueError( @@ -71,6 +76,8 @@ def __init__( ) self._get_src_value = self._get_output_value self.update = self._output_update + if device is None: + device = torch.device("cpu") self.alpha = alpha self.epoch_bound = epoch_bound @@ -118,5 +125,5 @@ def _metric_iteration_completed(self, engine: Engine) -> None: @reinit__is_reduced def _output_update(self, output: Union[torch.Tensor, float]) -> None: if isinstance(output, torch.Tensor): - output = output.detach().clone() + output = output.detach().to(self._device, copy=True) self.src = output diff --git a/tests/ignite/metrics/test_accumulation.py b/tests/ignite/metrics/test_accumulation.py index abd71f3ad1df..9aaf38bf9844 100644 --- a/tests/ignite/metrics/test_accumulation.py +++ b/tests/ignite/metrics/test_accumulation.py @@ -198,105 +198,120 @@ def compute_mean_std(engine, batch): def _test_distrib_variable_accumulation(device): + def _test(metric_device): + mean_var = VariableAccumulation(lambda a, x: a + x, device=metric_device) + y_true = torch.rand(100, device=device, dtype=torch.float64) - mean_var = VariableAccumulation(lambda a, x: a + x, device=device) - y_true = torch.rand(100, device=device, dtype=torch.float64) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + y_true = idist.all_reduce(y_true) + a, n = mean_var.compute() + assert a.item() == pytest.approx(y_true.sum().item()) + assert n == len(y_true) * idist.get_world_size() + # check if call compute twice + a, n = mean_var.compute() + assert a.item() == pytest.approx(y_true.sum().item()) + assert n == len(y_true) * idist.get_world_size() - y_true = idist.all_reduce(y_true) - a, n = mean_var.compute() - assert a.item() == pytest.approx(y_true.sum().item()) - assert n == len(y_true) * idist.get_world_size() - # check if call compute twice - a, n = mean_var.compute() - assert a.item() == pytest.approx(y_true.sum().item()) - assert n == len(y_true) * idist.get_world_size() + mean_var = VariableAccumulation(lambda a, x: a + x, device=metric_device) + y_true = torch.rand(50, 10, device=device, dtype=torch.float64) - mean_var = VariableAccumulation(lambda a, x: a + x, device=device) - y_true = torch.rand(50, 10, device=device, dtype=torch.float64) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + y_true = idist.all_reduce(y_true) + a, n = mean_var.compute() + assert n == len(y_true) * idist.get_world_size() + np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) + a, n = mean_var.compute() + assert n == len(y_true) * idist.get_world_size() + np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) - y_true = idist.all_reduce(y_true) - a, n = mean_var.compute() - assert n == len(y_true) * idist.get_world_size() - np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) - a, n = mean_var.compute() - assert n == len(y_true) * idist.get_world_size() - np.testing.assert_almost_equal(a.cpu().numpy(), y_true.sum(dim=0).cpu().numpy(), decimal=4) + # check multiple random inputs as random exact occurencies are rare + for _ in range(3): + _test("cpu") + _test(idist.device()) def _test_distrib_average(device): + def _test(metric_device): + with pytest.raises(NotComputableError): + v = Average(device=metric_device) + v.compute() - with pytest.raises(NotComputableError): - v = Average(device=device) - v.compute() + mean_var = Average(device=metric_device) + y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() + y_true = y_true.to(device) - mean_var = Average(device=device) - y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() - y_true = y_true.to(device) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + m = mean_var.compute() - m = mean_var.compute() + y_true = idist.all_reduce(y_true) + assert m.item() == pytest.approx(y_true.mean().item() / idist.get_world_size()) - y_true = idist.all_reduce(y_true) - assert m.item() == pytest.approx(y_true.mean().item() / idist.get_world_size()) + mean_var = Average(device=metric_device) + y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() + y_true = y_true.to(device) - mean_var = Average(device=device) - y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() - y_true = y_true.to(device) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + m = mean_var.compute() - m = mean_var.compute() + y_true = idist.all_reduce(y_true) + np.testing.assert_almost_equal( + m.cpu().numpy(), y_true.mean(dim=0).cpu().numpy() / idist.get_world_size(), decimal=5 + ) - y_true = idist.all_reduce(y_true) - np.testing.assert_almost_equal( - m.cpu().numpy(), y_true.mean(dim=0).cpu().numpy() / idist.get_world_size(), decimal=5 - ) + # check multiple random inputs as random exact occurencies are rare + for _ in range(3): + _test("cpu") + _test(idist.device()) def _test_distrib_geom_average(device): + def _test(metric_device): + with pytest.raises(NotComputableError): + v = GeometricAverage(device=metric_device) + v.compute() - with pytest.raises(NotComputableError): - v = GeometricAverage(device=device) - v.compute() + mean_var = GeometricAverage(device=metric_device) + y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() + y_true = y_true.to(device) - mean_var = GeometricAverage(device=device) - y_true = torch.rand(100, dtype=torch.float64) + torch.randint(0, 10, size=(100,)).double() - y_true = y_true.to(device) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + m = mean_var.compute() + log_y_true = torch.log(y_true) + log_y_true = idist.all_reduce(log_y_true) + assert m.item() == pytest.approx(torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).item()) - m = mean_var.compute() - log_y_true = torch.log(y_true) - log_y_true = idist.all_reduce(log_y_true) - assert m.item() == pytest.approx(torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).item()) + mean_var = GeometricAverage(device=metric_device) + y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() + y_true = y_true.to(device) - mean_var = GeometricAverage(device=device) - y_true = torch.rand(100, 10, dtype=torch.float64) + torch.randint(0, 10, size=(100, 10)).double() - y_true = y_true.to(device) + for y in y_true: + mean_var.update(y) - for y in y_true: - mean_var.update(y) + m = mean_var.compute() + log_y_true = torch.log(y_true) + log_y_true = idist.all_reduce(log_y_true) + np.testing.assert_almost_equal( + m.cpu().numpy(), torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).cpu().numpy(), decimal=5 + ) - m = mean_var.compute() - log_y_true = torch.log(y_true) - log_y_true = idist.all_reduce(log_y_true) - np.testing.assert_almost_equal( - m.cpu().numpy(), torch.exp(log_y_true.mean(dim=0) / idist.get_world_size()).cpu().numpy(), decimal=5 - ) + # check multiple random inputs as random exact occurencies are rare + for _ in range(3): + _test("cpu") + _test(idist.device()) def _test_distrib_integration(device): - def _test(metric_cls, true_result_fn, tol=1e-5): + def _test(metric_cls, true_result_fn, metric_device, tol=1e-5): size = 100 custom_variable = 10.0 + 5.0 * torch.rand(size, 12, dtype=torch.float64) @@ -307,7 +322,7 @@ def update_fn(engine, batch): engine = Engine(update_fn) - custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=device) + custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=metric_device) custom_var_mean.attach(engine, "agg_custom_var") state = engine.run([0] * size) @@ -326,7 +341,7 @@ def update_fn(engine, batch): engine = Engine(update_fn) - custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=device) + custom_var_mean = metric_cls(output_transform=lambda output: output[1], device=metric_device) custom_var_mean.attach(engine, "agg_custom_var") state = engine.run([0] * size) @@ -342,8 +357,25 @@ def _geom_mean(y_true): np_t = log_y_true.cpu().numpy() return np.exp(np.mean(np_t, axis=0) / idist.get_world_size()) - _test(Average, _mean) - _test(GeometricAverage, _geom_mean, tol=1e-4) + for metric_device in ["cpu", idist.device()]: + _test(Average, _mean, metric_device) + _test(GeometricAverage, _geom_mean, metric_device, tol=1e-4) + + +def _test_distrib_accumulator_device(device): + + for metric_device in [torch.device("cpu"), idist.device()]: + + m = VariableAccumulation(lambda a, x: x, device=metric_device) + assert m._device == metric_device + assert m.accumulator.device == metric_device, "{}:{} vs {}:{}".format( + type(m.accumulator.device), m.accumulator.device, type(metric_device), metric_device + ) + + m.update(torch.tensor(1, device=device)) + assert m.accumulator.device == metric_device, "{}:{} vs {}:{}".format( + type(m.accumulator.device), m.accumulator.device, type(metric_device), metric_device + ) @pytest.mark.distributed @@ -356,6 +388,7 @@ def test_distrib_gpu(distributed_context_single_node_nccl): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -367,6 +400,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -378,6 +412,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -392,6 +427,7 @@ def test_distrib_hvd(gloo_hvd_executor): gloo_hvd_executor(_test_distrib_average, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_geom_average, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_integration, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -403,6 +439,7 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -414,6 +451,7 @@ def test_distrib_single_device_xla(): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): @@ -422,6 +460,7 @@ def _test_distrib_xla_nprocs(index): _test_distrib_average(device) _test_distrib_geom_average(device) _test_distrib_integration(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_confusion_matrix.py b/tests/ignite/metrics/test_confusion_matrix.py index 2960a59b9d02..55867cdebe92 100644 --- a/tests/ignite/metrics/test_confusion_matrix.py +++ b/tests/ignite/metrics/test_confusion_matrix.py @@ -547,68 +547,90 @@ def test_dice_coefficient(): def _test_distrib_multiclass_images(device): + def _test(metric_device): + num_classes = 3 + cm = ConfusionMatrix(num_classes=num_classes, device=metric_device) - num_classes = 3 - cm = ConfusionMatrix(num_classes=num_classes, device=device) + y_true, y_pred = get_y_true_y_pred() - y_true, y_pred = get_y_true_y_pred() + # Compute confusion matrix with sklearn + true_res = confusion_matrix(y_true.reshape(-1), y_pred.reshape(-1)) - # Compute confusion matrix with sklearn - true_res = confusion_matrix(y_true.reshape(-1), y_pred.reshape(-1)) + th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred) + th_y_true = th_y_true.to(device) + th_y_logits = th_y_logits.to(device) - th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred) - th_y_true = th_y_true.to(device) - th_y_logits = th_y_logits.to(device) + # Update metric + output = (th_y_logits, th_y_true) + cm.update(output) - # Update metric - output = (th_y_logits, th_y_true) - cm.update(output) + res = cm.compute().cpu().numpy() / idist.get_world_size() - res = cm.compute().cpu().numpy() / idist.get_world_size() + assert np.all(true_res == res) - assert np.all(true_res == res) + # Another test on batch of 2 images + num_classes = 3 + cm = ConfusionMatrix(num_classes=num_classes, device=metric_device) + + # Create a batch of two images: + th_y_true1 = torch.from_numpy(y_true).reshape(1, 30, 30) + th_y_true2 = torch.from_numpy(y_true.transpose()).reshape(1, 30, 30) + th_y_true = torch.cat([th_y_true1, th_y_true2], dim=0) + th_y_true = th_y_true.to(device) + + # Create a batch of 2 logits tensors + y_probas = np.ones((3, 30, 30)) * -10 + y_probas[0, (y_pred == 0)] = 720 + y_probas[1, (y_pred == 1)] = 720 + y_probas[2, (y_pred == 2)] = 768 + th_y_logits1 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + + y_probas = np.ones((3, 30, 30)) * -10 + y_probas[0, (y_pred.transpose() == 0)] = 720 + y_probas[1, (y_pred.transpose() == 2)] = 720 + y_probas[2, (y_pred.transpose() == 1)] = 768 + th_y_logits2 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + + th_y_logits = torch.cat([th_y_logits1, th_y_logits2], dim=0) + # check update if input is on another device + th_y_logits = th_y_logits.to(device) + + # Update metric & compute + output = (th_y_logits, th_y_true) + cm.update(output) + res = cm.compute().cpu().numpy() - # Another test on batch of 2 images - num_classes = 3 - cm = ConfusionMatrix(num_classes=num_classes, device=device) + # Compute confusion matrix with sklearn + th_y_true = idist.all_gather(th_y_true) + th_y_logits = idist.all_gather(th_y_logits) - # Create a batch of two images: - th_y_true1 = torch.from_numpy(y_true).reshape(1, 30, 30) - th_y_true2 = torch.from_numpy(y_true.transpose()).reshape(1, 30, 30) - th_y_true = torch.cat([th_y_true1, th_y_true2], dim=0) - th_y_true = th_y_true.to(device) + np_y_true = th_y_true.cpu().numpy().reshape(-1) + np_y_pred = np.argmax(th_y_logits.cpu().numpy(), axis=1).reshape(-1) + true_res = confusion_matrix(np_y_true, np_y_pred) - # Create a batch of 2 logits tensors - y_probas = np.ones((3, 30, 30)) * -10 - y_probas[0, (y_pred == 0)] = 720 - y_probas[1, (y_pred == 1)] = 720 - y_probas[2, (y_pred == 2)] = 768 - th_y_logits1 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + assert np.all(true_res == res) - y_probas = np.ones((3, 30, 30)) * -10 - y_probas[0, (y_pred.transpose() == 0)] = 720 - y_probas[1, (y_pred.transpose() == 2)] = 720 - y_probas[2, (y_pred.transpose() == 1)] = 768 - th_y_logits2 = torch.from_numpy(y_probas).reshape(1, 3, 30, 30) + _test("cpu") + _test(idist.device()) - th_y_logits = torch.cat([th_y_logits1, th_y_logits2], dim=0) - # check update if input is on another device - th_y_logits = th_y_logits.to(device) - # Update metric & compute - output = (th_y_logits, th_y_true) - cm.update(output) - res = cm.compute().cpu().numpy() +def _test_distrib_accumulator_device(device): - # Compute confusion matrix with sklearn - th_y_true = idist.all_gather(th_y_true) - th_y_logits = idist.all_gather(th_y_logits) + for metric_device in [torch.device("cpu"), idist.device()]: - np_y_true = th_y_true.cpu().numpy().reshape(-1) - np_y_pred = np.argmax(th_y_logits.cpu().numpy(), axis=1).reshape(-1) - true_res = confusion_matrix(np_y_true, np_y_pred) + cm = ConfusionMatrix(num_classes=3, device=metric_device) + assert cm._device == metric_device + assert cm.confusion_matrix.device == metric_device, "{}:{} vs {}:{}".format( + type(cm.confusion_matrix.device), cm._num_correct.device, type(metric_device), metric_device + ) - assert np.all(true_res == res) + y_true, y_pred = get_y_true_y_pred() + th_y_true, th_y_logits = compute_th_y_true_y_logits(y_true, y_pred) + cm.update((th_y_logits, th_y_true)) + + assert cm.confusion_matrix.device == metric_device, "{}:{} vs {}:{}".format( + type(cm.confusion_matrix.device), acc._num_correct.device, type(metric_device), metric_device + ) @pytest.mark.distributed @@ -618,6 +640,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -626,6 +649,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo): device = "cpu" _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -637,6 +661,7 @@ def test_distrib_hvd(gloo_hvd_executor): nproc = 4 if not torch.cuda.is_available() else torch.cuda.device_count() gloo_hvd_executor(_test_distrib_multiclass_images, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -645,6 +670,7 @@ def test_distrib_hvd(gloo_hvd_executor): def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -653,6 +679,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -661,11 +688,13 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): def test_distrib_single_device_xla(): device = idist.device() _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_multiclass_images(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_loss.py b/tests/ignite/metrics/test_loss.py index 4739bca16d1f..e8a405ba93eb 100644 --- a/tests/ignite/metrics/test_loss.py +++ b/tests/ignite/metrics/test_loss.py @@ -76,47 +76,56 @@ def test_reset(): def _test_distrib_compute_on_criterion(device): + def _test(metric_device): + criterion = nn.NLLLoss().to(device) + loss = Loss(criterion, device=metric_device) - criterion = nn.NLLLoss().to(device) - loss = Loss(criterion, device=device) - - y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], device=device).log() - y = torch.tensor([2, 2], device=device).long() - loss.update((y_pred, y)) - n = loss._num_examples - assert n == len(y) - res = loss.compute() - assert n * idist.get_world_size() == loss._num_examples - - y_pred = idist.all_gather(y_pred) - y = idist.all_gather(y) - true_loss_value = criterion(y_pred, y) - assert_almost_equal(res, true_loss_value.item()) + y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]], device=device).log() + y = torch.tensor([2, 2], device=device).long() + loss.update((y_pred, y)) + n = loss._num_examples + assert n == len(y) + res = loss.compute() + assert n * idist.get_world_size() == loss._num_examples + + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + true_loss_value = criterion(y_pred, y) + assert_almost_equal(res, true_loss_value.item()) + + loss.reset() + y_pred = torch.tensor([[0.1, 0.3, 0.6], [0.6, 0.2, 0.2], [0.2, 0.7, 0.1]], device=device).log() + y = torch.tensor([2, 0, 2], device=device).long() + loss.update((y_pred, y)) + n = loss._num_examples + res = loss.compute() + assert n * idist.get_world_size() == loss._num_examples - loss.reset() - y_pred = torch.tensor([[0.1, 0.3, 0.6], [0.6, 0.2, 0.2], [0.2, 0.7, 0.1]], device=device).log() - y = torch.tensor([2, 0, 2], device=device).long() - loss.update((y_pred, y)) - n = loss._num_examples - res = loss.compute() - assert n * idist.get_world_size() == loss._num_examples + y_pred = idist.all_gather(y_pred) + y = idist.all_gather(y) + true_loss_value = criterion(y_pred, y) + assert_almost_equal(res, true_loss_value.item()) - y_pred = idist.all_gather(y_pred) - y = idist.all_gather(y) - true_loss_value = criterion(y_pred, y) - assert_almost_equal(res, true_loss_value.item()) + _test("cpu") + _test(idist.device()) def _test_distrib_sum_device(device): - device = torch.device(device) - loss = Loss(nll_loss, device=device) - assert loss._device == device - y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]]).log() - y = torch.tensor([2, 2]).long() - loss.update((y_pred, y)) + for metric_device in [torch.device("cpu"), idist.device()]: + loss = Loss(nll_loss, device=metric_device) + assert loss._device == metric_device + assert loss._sum.device == metric_device, "{}:{} vs {}:{}".format( + type(loss._sum.device), loss._sum.device, type(metric_device), metric_device + ) + + y_pred = torch.tensor([[0.1, 0.4, 0.5], [0.1, 0.7, 0.2]]).log() + y = torch.tensor([2, 2]).long() + loss.update((y_pred, y)) - assert loss._sum.device == device + assert loss._sum.device == metric_device, "{}:{} vs {}:{}".format( + type(loss._sum.device), loss._sum.device, type(metric_device), metric_device + ) def test_sum_detached(): diff --git a/tests/ignite/metrics/test_mean_absolute_error.py b/tests/ignite/metrics/test_mean_absolute_error.py index 557c4182f3b4..5e197e14bb9e 100644 --- a/tests/ignite/metrics/test_mean_absolute_error.py +++ b/tests/ignite/metrics/test_mean_absolute_error.py @@ -49,31 +49,47 @@ def update(engine, i): y_true[i * s + offset * rank : (i + 1) * s + offset * rank], ) - engine = Engine(update) + def _test(metric_device): + engine = Engine(update) - m = MeanAbsoluteError() - m.attach(engine, "mae") + m = MeanAbsoluteError(device=metric_device) + m.attach(engine, "mae") - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - assert "mae" in engine.state.metrics - res = engine.state.metrics["mae"] + assert "mae" in engine.state.metrics + res = engine.state.metrics["mae"] - true_res = np.mean(np.abs((y_true - y_preds).cpu().numpy())) + true_res = np.mean(np.abs((y_true - y_preds).cpu().numpy())) - assert pytest.approx(res) == true_res + assert pytest.approx(res) == true_res + + _test("cpu") + _test(idist.device()) def _test_distrib_accumulator_device(device): - device = torch.device(device) - mae = MeanAbsoluteError(device=device) - assert mae._device == device - y_pred = torch.tensor([[2.0], [-2.0]]) - y = torch.zeros(2) - mae.update((y_pred, y)) - assert mae._sum_of_absolute_errors.device == device + for metric_device in [torch.device("cpu"), idist.device()]: + mae = MeanAbsoluteError(device=metric_device) + assert mae._device == metric_device + assert mae._sum_of_absolute_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mae._sum_of_absolute_errors.device), + mae._sum_of_absolute_errors.device, + type(metric_device), + metric_device, + ) + + y_pred = torch.tensor([[2.0], [-2.0]]) + y = torch.zeros(2) + mae.update((y_pred, y)) + assert mae._sum_of_absolute_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mae._sum_of_absolute_errors.device), + mae._sum_of_absolute_errors.device, + type(metric_device), + metric_device, + ) def test_accumulator_detached(): diff --git a/tests/ignite/metrics/test_mean_pairwise_distance.py b/tests/ignite/metrics/test_mean_pairwise_distance.py index f5ea4e731ef7..df14e362dc26 100644 --- a/tests/ignite/metrics/test_mean_pairwise_distance.py +++ b/tests/ignite/metrics/test_mean_pairwise_distance.py @@ -52,42 +52,53 @@ def update(engine, i): y_true[i * s + offset * rank : (i + 1) * s + offset * rank, ...], ) - engine = Engine(update) - - m = MeanPairwiseDistance() - m.attach(engine, "mpwd") + def _test(metric_device): + engine = Engine(update) + + m = MeanPairwiseDistance(device=metric_device) + m.attach(engine, "mpwd") + + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) + + assert "mpwd" in engine.state.metrics + res = engine.state.metrics["mpwd"] + + true_res = [] + for i in range(n_iters * idist.get_world_size()): + true_res.append( + torch.pairwise_distance( + y_true[i * s : (i + 1) * s, ...], y_preds[i * s : (i + 1) * s, ...], p=m._p, eps=m._eps + ) + .cpu() + .numpy() + ) + true_res = np.array(true_res).ravel() + true_res = true_res.mean() - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + assert pytest.approx(res) == true_res - assert "mpwd" in engine.state.metrics - res = engine.state.metrics["mpwd"] + _test("cpu") + _test(idist.device()) - true_res = [] - for i in range(n_iters * idist.get_world_size()): - true_res.append( - torch.pairwise_distance( - y_true[i * s : (i + 1) * s, ...], y_preds[i * s : (i + 1) * s, ...], p=m._p, eps=m._eps - ) - .cpu() - .numpy() - ) - true_res = np.array(true_res).ravel() - true_res = true_res.mean() - assert pytest.approx(res) == true_res +def _test_distrib_accumulator_device(device): + for metric_device in [torch.device("cpu"), idist.device()]: -def _test_distrib_accumulator_device(device): - device = torch.device(device) - mpd = MeanPairwiseDistance(device=device) - assert mpd._device == device + mpd = MeanPairwiseDistance(device=metric_device) + assert mpd._device == metric_device + assert mpd._sum_of_distances.device == metric_device, "{}:{} vs {}:{}".format( + type(mpd._sum_of_distances.device), mpd._sum_of_distances.device, type(metric_device), metric_device + ) - y_pred = torch.Tensor([[3.0, 4.0], [-3.0, -4.0]]) - y = torch.zeros(2, 2) - mpd.update((y_pred, y)) + y_pred = torch.Tensor([[3.0, 4.0], [-3.0, -4.0]]) + y = torch.zeros(2, 2) + mpd.update((y_pred, y)) - assert mpd._sum_of_distances.device == device + assert mpd._sum_of_distances.device == metric_device, "{}:{} vs {}:{}".format( + type(mpd._sum_of_distances.device), mpd._sum_of_distances.device, type(metric_device), metric_device + ) def test_accumulator_detached(): diff --git a/tests/ignite/metrics/test_mean_squared_error.py b/tests/ignite/metrics/test_mean_squared_error.py index 08552836531f..e8bb18777e84 100644 --- a/tests/ignite/metrics/test_mean_squared_error.py +++ b/tests/ignite/metrics/test_mean_squared_error.py @@ -49,31 +49,49 @@ def update(engine, i): y_true[i * s + offset * rank : (i + 1) * s + offset * rank], ) - engine = Engine(update) + def _test(metric_device): + engine = Engine(update) - m = MeanSquaredError() - m.attach(engine, "mse") + m = MeanSquaredError(device=metric_device) + m.attach(engine, "mse") - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - assert "mse" in engine.state.metrics - res = engine.state.metrics["mse"] + assert "mse" in engine.state.metrics + res = engine.state.metrics["mse"] - true_res = np.mean(np.power((y_true - y_preds).cpu().numpy(), 2.0)) + true_res = np.mean(np.power((y_true - y_preds).cpu().numpy(), 2.0)) - assert pytest.approx(res, rel=tol) == true_res + assert pytest.approx(res, rel=tol) == true_res + + _test("cpu") + _test(idist.device()) def _test_distrib_accumulator_device(device): - device = torch.device(device) - mse = MeanSquaredError(device=device) - assert mse._device == device - y_pred = torch.tensor([[2.0], [-2.0]]) - y = torch.zeros(2) - mse.update((y_pred, y)) - assert mse._sum_of_squared_errors.device == device + for metric_device in [torch.device("cpu"), idist.device()]: + + device = torch.device(device) + mse = MeanSquaredError(device=metric_device) + assert mse._device == metric_device + assert mse._sum_of_squared_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mse._sum_of_squared_errors.device), + mse._sum_of_squared_errors.device, + type(metric_device), + metric_device, + ) + + y_pred = torch.tensor([[2.0], [-2.0]]) + y = torch.zeros(2) + mse.update((y_pred, y)) + assert mse._sum_of_squared_errors.device == metric_device, "{}:{} vs {}:{}".format( + type(mse._sum_of_squared_errors.device), + mse._sum_of_squared_errors.device, + type(metric_device), + metric_device, + ) def test_accumulator_detached(): diff --git a/tests/ignite/metrics/test_metrics_lambda.py b/tests/ignite/metrics/test_metrics_lambda.py index a6a0902496dd..3e98b9088967 100644 --- a/tests/ignite/metrics/test_metrics_lambda.py +++ b/tests/ignite/metrics/test_metrics_lambda.py @@ -325,7 +325,7 @@ def _test_distrib_integration(device): batch_size = 10 n_classes = 10 - def _test(): + def _test(metric_device): y_true = np.arange(0, n_iters * batch_size * idist.get_world_size(), dtype="int64") % n_classes y_pred = 0.2 * np.random.rand(n_iters * batch_size * idist.get_world_size(), n_classes) for i in range(n_iters * batch_size * idist.get_world_size()): @@ -345,8 +345,8 @@ def update_fn(engine, i): evaluator = Engine(update_fn) - precision = Precision(average=False, device=device) - recall = Recall(average=False, device=device) + precision = Precision(average=False, device=metric_device) + recall = Recall(average=False, device=metric_device) def Fbeta(r, p, beta): return torch.mean((1 + beta ** 2) * p * r / (beta ** 2 * p + r)).item() @@ -366,8 +366,9 @@ def Fbeta(r, p, beta): assert f1_true == approx(state.metrics["f1"]) assert 1.0 + f1_true == approx(state.metrics["ff1"]) - for _ in range(5): - _test() + for _ in range(3): + _test("cpu") + _test(idist.device()) @pytest.mark.distributed diff --git a/tests/ignite/metrics/test_precision.py b/tests/ignite/metrics/test_precision.py index 9ea0741fc835..1fb5b0fd4a5f 100644 --- a/tests/ignite/metrics/test_precision.py +++ b/tests/ignite/metrics/test_precision.py @@ -722,7 +722,7 @@ def _test_distrib_integration_multiclass(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -739,7 +739,7 @@ def update(engine, i): engine = Engine(update) - pr = Precision(average=average) + pr = Precision(average=average, device=metric_device) pr.attach(engine, "pr") data = list(range(n_iters)) @@ -748,7 +748,7 @@ def update(engine, i): assert "pr" in engine.state.metrics res = engine.state.metrics["pr"] if isinstance(res, torch.Tensor): - assert res.device.type == "cpu" + assert res.device == metric_device res = res.cpu().numpy() true_res = precision_score( @@ -758,10 +758,11 @@ def update(engine, i): assert pytest.approx(res) == true_res for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) - _test(average=False, n_epochs=1) - _test(average=False, n_epochs=2) + for metric_device in [torch.device("cpu"), idist.device()]: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) + _test(average=False, n_epochs=1, metric_device=metric_device) + _test(average=False, n_epochs=2, metric_device=metric_device) def _test_distrib_integration_multilabel(device): @@ -771,7 +772,7 @@ def _test_distrib_integration_multilabel(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -813,8 +814,9 @@ def update(engine, i): assert pytest.approx(res) == true_res for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) + for metric_device in ["cpu", idist.device()]: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) if idist.get_world_size() > 1: with pytest.warns( @@ -835,38 +837,57 @@ def update(engine, i): def _test_distrib_accumulator_device(device): # Binary accuracy on input of shape (N, 1) or (N, ) - device = torch.device(device) - def _test(average): - pr = Precision(average=average, device=device) - assert pr._device == device + def _test(average, metric_device): + pr = Precision(average=average, device=metric_device) + assert pr._device == metric_device + # Since the shape of the accumulated amount isn't known before the first update + # call, the internal variables aren't tensors on the right device yet. y_pred = torch.randint(0, 2, size=(10,)) y = torch.randint(0, 2, size=(10,)).long() pr.update((y_pred, y)) - assert pr._true_positives.device == device - assert pr._positives.device == device + assert pr._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._true_positives.device), pr._true_positives.device, type(metric_device), metric_device + ) + assert pr._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._positives.device), pr._positives.device, type(metric_device), metric_device + ) - _test(True) - _test(False) + for metric_device in [torch.device("cpu"), idist.device()]: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) def _test_distrib_multilabel_accumulator_device(device): # Multiclass input data of shape (N, ) and (N, C) - device = torch.device(device) - def _test(average): - pr = Precision(is_multilabel=True, average=average, device=device) + def _test(average, metric_device): + pr = Precision(is_multilabel=True, average=average, device=metric_device) + + assert pr._device == metric_device + assert pr._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._true_positives.device), pr._true_positives.device, type(metric_device), metric_device + ) + assert pr._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._positives.device), pr._positives.device, type(metric_device), metric_device + ) + y_pred = torch.randint(0, 2, size=(10, 4, 20, 23)) y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() pr.update((y_pred, y)) - assert pr._true_positives.device == device - assert pr._positives.device == device + assert pr._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._true_positives.device), pr._true_positives.device, type(metric_device), metric_device + ) + assert pr._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(pr._positives.device), pr._positives.device, type(metric_device), metric_device + ) - _test(True) - _test(False) + for metric_device in [torch.device("cpu"), idist.device()]: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) @pytest.mark.distributed diff --git a/tests/ignite/metrics/test_recall.py b/tests/ignite/metrics/test_recall.py index b783b2ec83ce..a170f3c0ac6d 100644 --- a/tests/ignite/metrics/test_recall.py +++ b/tests/ignite/metrics/test_recall.py @@ -722,7 +722,7 @@ def _test_distrib_integration_multiclass(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -739,7 +739,7 @@ def update(engine, i): engine = Engine(update) - re = Recall(average=average) + re = Recall(average=average, device=metric_device) re.attach(engine, "re") data = list(range(n_iters)) @@ -748,7 +748,7 @@ def update(engine, i): assert "re" in engine.state.metrics res = engine.state.metrics["re"] if isinstance(res, torch.Tensor): - assert res.device.type == "cpu" + assert res.device == metric_device res = res.cpu().numpy() true_res = recall_score( @@ -758,10 +758,11 @@ def update(engine, i): assert pytest.approx(res) == true_res for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) - _test(average=False, n_epochs=1) - _test(average=False, n_epochs=2) + for metric_device in [torch.device("cpu"), idist.device()]: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) + _test(average=False, n_epochs=1, metric_device=metric_device) + _test(average=False, n_epochs=2, metric_device=metric_device) def _test_distrib_integration_multilabel(device): @@ -771,7 +772,7 @@ def _test_distrib_integration_multilabel(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(average, n_epochs): + def _test(average, n_epochs, metric_device): n_iters = 60 s = 16 n_classes = 7 @@ -788,7 +789,7 @@ def update(engine, i): engine = Engine(update) - re = Recall(average=average, is_multilabel=True) + re = Recall(average=average, is_multilabel=True, device=metric_device) re.attach(engine, "re") data = list(range(n_iters)) @@ -813,8 +814,9 @@ def update(engine, i): assert pytest.approx(res) == true_res for _ in range(2): - _test(average=True, n_epochs=1) - _test(average=True, n_epochs=2) + for metric_device in ["cpu", idist.device()]: + _test(average=True, n_epochs=1, metric_device=metric_device) + _test(average=True, n_epochs=2, metric_device=metric_device) if idist.get_world_size() > 1: with pytest.warns( @@ -835,38 +837,57 @@ def update(engine, i): def _test_distrib_accumulator_device(device): # Binary accuracy on input of shape (N, 1) or (N, ) - device = torch.device(device) - def _test(average): - re = Recall(average=average, device=device) - assert re._device == device + def _test(average, metric_device): + re = Recall(average=average, device=metric_device) + assert re._device == metric_device + # Since the shape of the accumulated amount isn't known before the first update + # call, the internal variables aren't tensors on the right device yet. y_reed = torch.randint(0, 2, size=(10,)) y = torch.randint(0, 2, size=(10,)).long() re.update((y_reed, y)) - assert re._true_positives.device == device - assert re._positives.device == device + assert re._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._true_positives.device), re._true_positives.device, type(metric_device), metric_device + ) + assert re._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._positives.device), re._positives.device, type(metric_device), metric_device + ) - _test(True) - _test(False) + for metric_device in [torch.device("cpu"), idist.device()]: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) def _test_distrib_multilabel_accumulator_device(device): # Multiclass input data of shape (N, ) and (N, C) - device = torch.device(device) - def _test(average): - re = Recall(is_multilabel=True, average=average, device=device) + def _test(average, metric_device): + re = Recall(is_multilabel=True, average=average, device=metric_device) + + assert re._device == metric_device + assert re._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._true_positives.device), re._true_positives.device, type(metric_device), metric_device + ) + assert re._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._positives.device), re._positives.device, type(metric_device), metric_device + ) + y_reed = torch.randint(0, 2, size=(10, 4, 20, 23)) y = torch.randint(0, 2, size=(10, 4, 20, 23)).long() re.update((y_reed, y)) - assert re._true_positives.device == device - assert re._positives.device == device + assert re._true_positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._true_positives.device), re._true_positives.device, type(metric_device), metric_device + ) + assert re._positives.device == metric_device, "{}:{} vs {}:{}".format( + type(re._positives.device), re._positives.device, type(metric_device), metric_device + ) - _test(True) - _test(False) + for metric_device in [torch.device("cpu"), idist.device()]: + _test(True, metric_device=metric_device) + _test(False, metric_device=metric_device) @pytest.mark.distributed diff --git a/tests/ignite/metrics/test_root_mean_squared_error.py b/tests/ignite/metrics/test_root_mean_squared_error.py index 878ef9df367d..b7068be11867 100644 --- a/tests/ignite/metrics/test_root_mean_squared_error.py +++ b/tests/ignite/metrics/test_root_mean_squared_error.py @@ -46,25 +46,29 @@ def _test_distrib_integration(device, tol=1e-6): def update(engine, i): return y_preds[i * s : (i + 1) * s], y_true[i * s + offset * rank : (i + 1) * s + offset * rank] - engine = Engine(update) + def _test(metric_device): + engine = Engine(update) - m = RootMeanSquaredError() - m.attach(engine, "rmse") + m = RootMeanSquaredError(device=metric_device) + m.attach(engine, "rmse") - data = list(range(n_iters)) - engine.run(data=data, max_epochs=1) + data = list(range(n_iters)) + engine.run(data=data, max_epochs=1) - assert "rmse" in engine.state.metrics - res = engine.state.metrics["rmse"] + assert "rmse" in engine.state.metrics + res = engine.state.metrics["rmse"] - y_preds_full = [] - for i in range(idist.get_world_size()): - y_preds_full.append((i + 1) * torch.ones(offset)) - y_preds_full = torch.stack(y_preds_full).to(device).flatten() + y_preds_full = [] + for i in range(idist.get_world_size()): + y_preds_full.append((i + 1) * torch.ones(offset)) + y_preds_full = torch.stack(y_preds_full).to(device).flatten() - true_res = np.sqrt(np.mean(np.square((y_true - y_preds_full).cpu().numpy()))) + true_res = np.sqrt(np.mean(np.square((y_true - y_preds_full).cpu().numpy()))) - assert pytest.approx(res, rel=tol) == true_res + assert pytest.approx(res, rel=tol) == true_res + + _test("cpu") + _test(idist.device()) @pytest.mark.distributed diff --git a/tests/ignite/metrics/test_running_average.py b/tests/ignite/metrics/test_running_average.py index ab1528e10596..1d729b97c5d0 100644 --- a/tests/ignite/metrics/test_running_average.py +++ b/tests/ignite/metrics/test_running_average.py @@ -305,61 +305,84 @@ def _test_distrib_on_metric(device): batch_size = 10 n_classes = 10 - data = list(range(n_iters)) - np.random.seed(12) - all_y_true_batch_values = np.random.randint( - 0, n_classes, size=(idist.get_world_size(), n_epochs * n_iters, batch_size) - ) - all_y_pred_batch_values = np.random.rand(idist.get_world_size(), n_epochs * n_iters, batch_size, n_classes) + def _test(metric_device): + data = list(range(n_iters)) + np.random.seed(12) + all_y_true_batch_values = np.random.randint( + 0, n_classes, size=(idist.get_world_size(), n_epochs * n_iters, batch_size) + ) + all_y_pred_batch_values = np.random.rand(idist.get_world_size(), n_epochs * n_iters, batch_size, n_classes) - y_true_batch_values = iter(all_y_true_batch_values[rank, ...]) - y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...]) + y_true_batch_values = iter(all_y_true_batch_values[rank, ...]) + y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...]) - def update_fn(engine, batch): - y_true_batch = next(y_true_batch_values) - y_pred_batch = next(y_pred_batch_values) - return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) + def update_fn(engine, batch): + y_true_batch = next(y_true_batch_values) + y_pred_batch = next(y_pred_batch_values) + return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch) - trainer = Engine(update_fn) - alpha = 0.98 + trainer = Engine(update_fn) + alpha = 0.98 - acc_metric = RunningAverage( - Accuracy(output_transform=lambda x: [x[0], x[1]], device=device), alpha=alpha, epoch_bound=False - ) - acc_metric.attach(trainer, "running_avg_accuracy") + acc_metric = RunningAverage( + Accuracy(output_transform=lambda x: [x[0], x[1]], device=metric_device), alpha=alpha, epoch_bound=False + ) + acc_metric.attach(trainer, "running_avg_accuracy") + + running_avg_acc = [ + None, + ] + true_acc_metric = Accuracy(device=metric_device) + + @trainer.on(Events.ITERATION_COMPLETED) + def manual_running_avg_acc(engine): + i = engine.state.iteration - 1 + + true_acc_metric.reset() + for j in range(idist.get_world_size()): + output = ( + torch.from_numpy(all_y_pred_batch_values[j, i, :, :]), + torch.from_numpy(all_y_true_batch_values[j, i, :]), + ) + true_acc_metric.update(output) + + batch_acc = true_acc_metric._num_correct.item() * 1.0 / true_acc_metric._num_examples + + if running_avg_acc[0] is None: + running_avg_acc[0] = batch_acc + else: + running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc + engine.state.running_avg_acc = running_avg_acc[0] + + @trainer.on(Events.ITERATION_COMPLETED) + def assert_equal_running_avg_acc_values(engine): + assert engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"], "{} vs {}".format( + engine.state.running_avg_acc, engine.state.metrics["running_avg_accuracy"] + ) - running_avg_acc = [ - None, - ] - true_acc_metric = Accuracy(device=device) + trainer.run(data, max_epochs=3) - @trainer.on(Events.ITERATION_COMPLETED) - def manual_running_avg_acc(engine): - i = engine.state.iteration - 1 + _test("cpu") + _test(idist.device()) - true_acc_metric.reset() - for j in range(idist.get_world_size()): - output = ( - torch.from_numpy(all_y_pred_batch_values[j, i, :, :]), - torch.from_numpy(all_y_true_batch_values[j, i, :]), - ) - true_acc_metric.update(output) - batch_acc = true_acc_metric._num_correct.item() * 1.0 / true_acc_metric._num_examples +def _test_distrib_accumulator_device(device): - if running_avg_acc[0] is None: - running_avg_acc[0] = batch_acc - else: - running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc - engine.state.running_avg_acc = running_avg_acc[0] + for metric_device in [torch.device("cpu"), idist.device()]: - @trainer.on(Events.ITERATION_COMPLETED) - def assert_equal_running_avg_acc_values(engine): - assert engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"], "{} vs {}".format( - engine.state.running_avg_acc, engine.state.metrics["running_avg_accuracy"] - ) + # Don't test the src=Metric case because compute() returns a scalar, + # so the metric doesn't accumulate on the device specified + avg = RunningAverage(output_transform=lambda x: x, device=metric_device) + assert avg._device == metric_device + # Value is None until the first update then compute call - trainer.run(data, max_epochs=3) + for _ in range(3): + avg.update(torch.tensor(1.0, device=device)) + avg.compute() + + assert avg._value.device == metric_device, "{}:{} vs {}:{}".format( + type(avg._value.device), avg._value.device, type(metric_device), metric_device + ) @pytest.mark.distributed @@ -370,6 +393,7 @@ def test_distrib_gpu(local_rank, distributed_context_single_node_nccl): device = "cuda:{}".format(local_rank) _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -379,6 +403,7 @@ def test_distrib_cpu(distributed_context_single_node_gloo): device = "cpu" _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.distributed @@ -391,6 +416,7 @@ def test_distrib_hvd(gloo_hvd_executor): gloo_hvd_executor(_test_distrib_on_output, (device,), np=nproc, do_init=True) gloo_hvd_executor(_test_distrib_on_metric, (device,), np=nproc, do_init=True) + gloo_hvd_executor(_test_distrib_accumulator_device, (device,), np=nproc, do_init=True) @pytest.mark.multinode_distributed @@ -400,6 +426,7 @@ def test_multinode_distrib_cpu(distributed_context_multi_node_gloo): device = "cpu" _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.multinode_distributed @@ -409,6 +436,7 @@ def test_multinode_distrib_gpu(distributed_context_multi_node_nccl): device = "cuda:{}".format(distributed_context_multi_node_nccl["local_rank"]) _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu @@ -418,12 +446,14 @@ def test_distrib_single_device_xla(): device = idist.device() _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) def _test_distrib_xla_nprocs(index): device = idist.device() _test_distrib_on_output(device) _test_distrib_on_metric(device) + _test_distrib_accumulator_device(device) @pytest.mark.tpu diff --git a/tests/ignite/metrics/test_top_k_categorical_accuracy.py b/tests/ignite/metrics/test_top_k_categorical_accuracy.py index 0f22d2f6b697..b16cd4efce3a 100644 --- a/tests/ignite/metrics/test_top_k_categorical_accuracy.py +++ b/tests/ignite/metrics/test_top_k_categorical_accuracy.py @@ -59,7 +59,7 @@ def _test_distrib_integration(device): rank = idist.get_rank() torch.manual_seed(12) - def _test(n_epochs): + def _test(n_epochs, metric_device): n_iters = 100 s = 16 n_classes = 10 @@ -79,7 +79,7 @@ def update(engine, i): engine = Engine(update) k = 5 - acc = TopKCategoricalAccuracy(k=k, device=device) + acc = TopKCategoricalAccuracy(k=k, device=metric_device) acc.attach(engine, "acc") data = list(range(n_iters)) @@ -94,20 +94,29 @@ def update(engine, i): assert pytest.approx(res) == true_res - for _ in range(5): - _test(n_epochs=1) - _test(n_epochs=2) + for _ in range(3): + for metric_device in ["cpu", idist.device()]: + _test(n_epochs=1, metric_device=metric_device) + _test(n_epochs=2, metric_device=metric_device) def _test_distrib_accumulator_device(device): - device = torch.device(device) - acc = TopKCategoricalAccuracy(2, device=device) - assert acc._device == device - y_pred = torch.tensor([[0.2, 0.4, 0.6, 0.8], [0.8, 0.6, 0.4, 0.2]]) - y = torch.ones(2).long() - acc.update((y_pred, y)) - assert acc._num_correct.device == device + for metric_device in [torch.device("cpu"), idist.device()]: + + acc = TopKCategoricalAccuracy(2, device=metric_device) + assert acc._device == metric_device + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) + + y_pred = torch.tensor([[0.2, 0.4, 0.6, 0.8], [0.8, 0.6, 0.4, 0.2]]) + y = torch.ones(2).long() + acc.update((y_pred, y)) + + assert acc._num_correct.device == metric_device, "{}:{} vs {}:{}".format( + type(acc._num_correct.device), acc._num_correct.device, type(metric_device), metric_device + ) @pytest.mark.distributed