update docs and docstrings for updated metrics #1239


Merged
merged 33 commits into from
Aug 18, 2020
Changes from all commits
33 commits
9f7daa1
update accuracy to accumulate _num_correct in a tensor on the right d…
n2cholas Aug 7, 2020
a87f93d
update loss metric to accumulate _sum in a tensor on the right device
n2cholas Aug 7, 2020
30b2e19
update mae metric to accumulate in a tensor on the right device
n2cholas Aug 7, 2020
a3e237c
update mpd metric to accumulate in a tensor on the right device
n2cholas Aug 7, 2020
7100176
update mse metric to accumulate in a tensor on the right device
n2cholas Aug 7, 2020
3228a0a
update top k accuracy metric to accumulate in a tensor on the right …
n2cholas Aug 7, 2020
412551e
update precision and recall metrics to accumulate in tensors on the r…
n2cholas Aug 8, 2020
4c4a76c
.....
n2cholas Aug 8, 2020
b1e6956
black formatting
n2cholas Aug 8, 2020
b081e92
reverted run*.sh
n2cholas Aug 10, 2020
a343c35
change all metrics default device to cpu except running_average
n2cholas Aug 16, 2020
8548601
Update ignite/metrics/precision.py
n2cholas Aug 16, 2020
b84226b
remove Optional type from metric devices since default is cpu
n2cholas Aug 16, 2020
685c23b
add comment explaining lack of detach in accuracy metrics
n2cholas Aug 16, 2020
0b4337d
update docstrings and docs
n2cholas Aug 17, 2020
b2fa213
Update ignite/metrics/accumulation.py
n2cholas Aug 17, 2020
90e0e9a
Update ignite/metrics/accumulation.py
n2cholas Aug 17, 2020
6c1fda4
Update ignite/metrics/accumulation.py
n2cholas Aug 17, 2020
c510e10
Update ignite/metrics/accuracy.py
n2cholas Aug 17, 2020
d5d4854
Update ignite/metrics/fbeta.py
n2cholas Aug 17, 2020
39515b7
Update ignite/metrics/loss.py
n2cholas Aug 17, 2020
3c49871
Update ignite/metrics/metric.py
n2cholas Aug 17, 2020
6de10dd
Update ignite/metrics/precision.py
n2cholas Aug 17, 2020
eca0bc3
Update ignite/metrics/recall.py
n2cholas Aug 17, 2020
ad7082e
add comment explaining lack of detach in metrics docs
n2cholas Aug 17, 2020
c057d52
Merge remote-tracking branch 'pytorch-ignite/metrics_impl' into metri…
n2cholas Aug 17, 2020
90b5b85
support device argument for running_average
n2cholas Aug 17, 2020
3481da1
update support for device argumenet for accumulation
n2cholas Aug 18, 2020
d340bb7
fix and improve device tests for metrics
n2cholas Aug 18, 2020
4824e24
fix and improve device tests for metrics
n2cholas Aug 18, 2020
1361866
fix TPU tests
n2cholas Aug 18, 2020
556262b
Apply suggestions from code review
vfdev-5 Aug 18, 2020
489620b
Apply suggestions from code review
vfdev-5 Aug 18, 2020
17 changes: 12 additions & 5 deletions docs/source/metrics.rst
@@ -120,15 +120,15 @@ specific condition (e.g. ignore user-defined classes):

class CustomAccuracy(Metric):

def __init__(self, ignored_class, output_transform=lambda x: x):
def __init__(self, ignored_class, output_transform=lambda x: x, device="cpu"):
self.ignored_class = ignored_class
self._num_correct = None
self._num_examples = None
super(CustomAccuracy, self).__init__(output_transform=output_transform)
super(CustomAccuracy, self).__init__(output_transform=output_transform, device=device)

@reinit__is_reduced
def reset(self):
self._num_correct = 0
self._num_correct = torch.tensor(0, device=self._device)
self._num_examples = 0
super(CustomAccuracy, self).reset()

@@ -144,21 +144,28 @@ specific condition (e.g. ignore user-defined classes):
indices = indices[mask]
correct = torch.eq(indices, y).view(-1)

self._num_correct += torch.sum(correct).item()
# We must detach tensors before adding them to the internal variables. Here, torch.eq is not
# differentiable, so its result is implicitly detached from the computation graph
self._num_correct += torch.sum(correct).to(self._device)
self._num_examples += correct.shape[0]

@sync_all_reduce("_num_examples", "_num_correct")
def compute(self):
if self._num_examples == 0:
raise NotComputableError('CustomAccuracy must have at least one example before it can be computed.')
return self._num_correct / self._num_examples
return self._num_correct.item() / self._num_examples


We imported the necessary classes, such as :class:`~ignite.metrics.Metric` and :class:`~ignite.exceptions.NotComputableError`,
along with the decorators that adapt the metric to a distributed setting. In the ``reset`` method, we reset the internal
variables ``_num_correct`` and ``_num_examples``, which are used to compute the custom metric. In the ``update`` method,
we define how to update those internal variables. Finally, in the ``compute`` method, we compute the metric value.

Notice that ``_num_correct`` is a tensor, since ``update`` accumulates tensor values, while ``_num_examples`` is a Python
scalar, since it accumulates plain integers. For differentiable metrics, you must detach the accumulated values before
adding them to the internal variables. Accuracy is not differentiable (specifically, the ``torch.eq`` call), so its
result is implicitly detached from the computation graph.

We can check this implementation in a simple case:

.. code-block:: python
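The device-aware pattern the docs describe can be sketched as a standalone class (hypothetical name `SimpleAccuracy`, without the ignite `Metric` base class or the distributed decorators):

```python
import torch

class SimpleAccuracy:
    """Standalone sketch of the pattern from the docs above.

    _num_correct lives in a tensor on the metric's device so updates stay
    on that device; _num_examples is a plain Python int.
    """

    def __init__(self, device="cpu"):
        self._device = torch.device(device)
        self.reset()

    def reset(self):
        self._num_correct = torch.tensor(0, device=self._device)
        self._num_examples = 0

    def update(self, y_pred, y):
        indices = torch.argmax(y_pred, dim=1)
        correct = torch.eq(indices, y).view(-1)
        # torch.eq is not differentiable, so no explicit detach is needed here
        self._num_correct += torch.sum(correct).to(self._device)
        self._num_examples += correct.shape[0]

    def compute(self):
        if self._num_examples == 0:
            raise ValueError("at least one example is required")
        return self._num_correct.item() / self._num_examples

metric = SimpleAccuracy()
metric.update(torch.tensor([[0.9, 0.1], [0.2, 0.8]]), torch.tensor([0, 0]))
print(metric.compute())  # 0.5
```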
16 changes: 12 additions & 4 deletions ignite/metrics/accumulation.py
@@ -31,7 +31,9 @@ class VariableAccumulation(Metric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

@@ -70,6 +72,9 @@ def update(self, output: Union[Any, torch.Tensor, numbers.Number]) -> None:
output = output.to(self._device)

self.accumulator = self._op(self.accumulator, output)
if isinstance(self.accumulator, torch.Tensor):
self.accumulator = self.accumulator.to(self._device)

if hasattr(output, "shape"):
self.num_examples += output.shape[0] if len(output.shape) > 1 else 1
else:
@@ -114,8 +119,9 @@ class Average(VariableAccumulation):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.

device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.
"""

def __init__(
@@ -160,7 +166,9 @@ class GeometricAverage(VariableAccumulation):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

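The `update` change above moves the accumulator back to the metric's device after applying the op, since the op may return a tensor on the input's device. A simplified sketch of that pattern (hypothetical name `SumAccumulator`, no ignite base class):

```python
import torch

class SumAccumulator:
    """Sketch of the VariableAccumulation update pattern in the diff above."""

    def __init__(self, op=torch.add, device="cpu"):
        self._op = op
        self._device = torch.device(device)
        self.accumulator = torch.tensor(0.0, device=self._device)
        self.num_examples = 0

    def update(self, output):
        if isinstance(output, torch.Tensor):
            output = output.detach().to(self._device)
        self.accumulator = self._op(self.accumulator, output)
        # The op may have produced a tensor elsewhere; pin it to our device
        if isinstance(self.accumulator, torch.Tensor):
            self.accumulator = self.accumulator.to(self._device)
        if hasattr(output, "shape"):
            self.num_examples += output.shape[0] if len(output.shape) > 1 else 1
        else:
            self.num_examples += 1

acc = SumAccumulator()
acc.update(torch.tensor([1.0, 2.0]))
acc.update(3.0)
print(acc.accumulator.sum().item())  # 9.0
```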
4 changes: 3 additions & 1 deletion ignite/metrics/accuracy.py
@@ -122,7 +122,9 @@ def thresholded_output_transform(output):
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
is_multilabel (bool, optional): flag to use in multilabel case. By default, False.
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

Expand Down
4 changes: 3 additions & 1 deletion ignite/metrics/confusion_matrix.py
@@ -30,7 +30,9 @@ class ConfusionMatrix(Metric):
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

Note:
In case of the targets `y` in `(batch_size, ...)` format, target indices between 0 and `num_classes` only
4 changes: 3 additions & 1 deletion ignite/metrics/fbeta.py
@@ -28,7 +28,9 @@ def Fbeta(
output_transform (callable, optional): a callable that is used to transform the
:class:`~ignite.engine.engine.Engine`'s ``process_function``'s output into the
form expected by the metric. It is used only if precision or recall are not provided.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

Returns:
MetricsLambda, F-beta metric
4 changes: 3 additions & 1 deletion ignite/metrics/loss.py
@@ -26,7 +26,9 @@ class Loss(Metric):
keywords arguments. If extra keywords arguments are provided they are passed to `loss_fn`.
batch_size (callable): a callable taking a target tensor that returns the
first dimension size (usually the batch size).
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

"""

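The Loss metric accumulates a batch-size-weighted sum of per-batch average losses; the docstring change above says that sum now lives on the metric's device. A simplified sketch (hypothetical name `LossAverageSketch`, omitting ignite's `batch_size` callable and extra kwargs):

```python
import torch
import torch.nn.functional as F

class LossAverageSketch:
    """Sketch of the Loss metric's accumulation pattern."""

    def __init__(self, loss_fn, device="cpu"):
        self._loss_fn = loss_fn
        self._device = torch.device(device)
        self._sum = torch.tensor(0.0, device=self._device)
        self._num_examples = 0

    def update(self, y_pred, y):
        # loss_fn returns the batch mean; detach it from the graph and
        # weight by batch size before accumulating on the metric's device
        average_loss = self._loss_fn(y_pred, y).detach()
        n = y.shape[0]
        self._sum += average_loss.to(self._device) * n
        self._num_examples += n

    def compute(self):
        return self._sum.item() / self._num_examples

loss = LossAverageSketch(F.l1_loss)
loss.update(torch.tensor([1.0, 3.0]), torch.tensor([0.0, 0.0]))
print(loss.compute())  # 2.0
```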
4 changes: 3 additions & 1 deletion ignite/metrics/metric.py
@@ -122,7 +122,9 @@ class Metric(metaclass=ABCMeta):
form expected by the metric. This can be useful if, for example, you have a multi-output model and
you want to compute the metric with respect to one of the outputs.
By default, metrics require the output as ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``.
device (str of torch.device, optional): optional device specification for internal storage.
device (str or torch.device): specifies which device updates are accumulated on. Setting the
metric's device to be the same as your ``update`` arguments ensures the ``update`` method is
non-blocking. By default, CPU.

"""

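The rationale in the new docstring is that each `update` moves incoming tensors onto the metric's device, and `.to()` is a no-op when the tensor is already there, so matching devices avoids a cross-device copy on every batch. A minimal illustration (the `accumulate` helper is hypothetical, not part of ignite's API):

```python
import torch

def accumulate(total, batch_value, metric_device):
    # When batch_value already lives on metric_device, .to() returns the
    # same tensor; otherwise it triggers a device-to-device transfer.
    return total + batch_value.to(metric_device)

total = torch.tensor(0.0, device="cpu")
total = accumulate(total, torch.tensor(2.5), torch.device("cpu"))
total = accumulate(total, torch.tensor(1.5), torch.device("cpu"))
print(total.item())  # 4.0
```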
4 changes: 3 additions & 1 deletion ignite/metrics/precision.py
@@ -122,7 +122,9 @@ def thresholded_output_transform(output):
in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case).
is_multilabel (bool, optional) flag to use in multilabel case. By default, value is False. If True, average
parameter should be True and the average is computed across samples, instead of classes.
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

4 changes: 3 additions & 1 deletion ignite/metrics/recall.py
@@ -60,7 +60,9 @@ def thresholded_output_transform(output):
in multiclass case), otherwise, returns a tensor with the precision (for each class in multiclass case).
is_multilabel (bool, optional) flag to use in multilabel case. By default, value is False. If True, average
parameter should be True and the average is computed across samples, instead of classes.
device (str of torch.device, optional): unused argument.
device (str or torch.device): specifies which device updates are accumulated on. Setting the metric's
device to be the same as your ``update`` arguments ensures the ``update`` method is non-blocking. By
default, CPU.

"""

11 changes: 9 additions & 2 deletions ignite/metrics/running_average.py
@@ -20,7 +20,11 @@ class RunningAverage(Metric):
corresponds the output of process function. Otherwise it should be None.
epoch_bound (boolean, optional): whether the running average should be reset after each epoch (defaults
to True).
device (str of torch.device, optional): unused argument.
device (str or torch.device, optional): specifies which device updates are accumulated on. Should be
None when ``src`` is an instance of :class:`~ignite.metrics.Metric`, as the running average will
use the ``src``'s device. Otherwise, defaults to CPU. Only applicable when the computed value
from the metric is a tensor.


Examples:

@@ -63,6 +67,7 @@ def __init__(
self.src = src
self._get_src_value = self._get_metric_value
self.iteration_completed = self._metric_iteration_completed
device = src._device
else:
if output_transform is None:
raise ValueError(
@@ -71,6 +76,8 @@
)
self._get_src_value = self._get_output_value
self.update = self._output_update
if device is None:
device = torch.device("cpu")

self.alpha = alpha
self.epoch_bound = epoch_bound
@@ -118,5 +125,5 @@ def _metric_iteration_completed(self, engine: Engine) -> None:
@reinit__is_reduced
def _output_update(self, output: Union[torch.Tensor, float]) -> None:
if isinstance(output, torch.Tensor):
output = output.detach().clone()
output = output.detach().to(self._device, copy=True)
self.src = output
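The `_output_update` change above detaches the output and copies it onto the metric's device before it feeds the exponential moving average. A simplified sketch of the output-based path (hypothetical name `RunningAverageSketch`; the attached-`Metric` source path is omitted):

```python
import torch

class RunningAverageSketch:
    """Sketch of RunningAverage's output-based update path."""

    def __init__(self, alpha=0.98, device="cpu"):
        self.alpha = alpha
        self._device = torch.device(device)
        self._value = None

    def update(self, output):
        if isinstance(output, torch.Tensor):
            # detach from autograd and copy onto the metric's device,
            # mirroring output.detach().to(self._device, copy=True)
            output = output.detach().to(self._device, copy=True)
        if self._value is None:
            self._value = output
        else:
            self._value = self._value * self.alpha + (1.0 - self.alpha) * output

    def compute(self):
        return self._value

ra = RunningAverageSketch(alpha=0.5)
ra.update(torch.tensor(2.0))
ra.update(torch.tensor(4.0))
print(ra.compute().item())  # 3.0
```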