
Customization of reduction/gathering ops for Metric #1288

Open
@vfdev-5


🚀 Feature

The idea is to make Metric's reduction/gathering ops configurable. By default, our own implementation is used, but users can override these functions globally, for example when using a custom, unsupported distributed framework or when dealing with asymmetry like here.

EDIT:

When a metric is implemented, its `reset` and `update` methods are decorated with `reinit__is_reduced`, and its `compute` method with `sync_all_reduce`.
`sync_all_reduce` is implemented as follows:

```python
from functools import wraps
from typing import Any, Callable

import ignite.distributed as idist

# Note: `Metric` is defined in the same module (ignite/metrics/metric.py).


def sync_all_reduce(*attrs: Any) -> Callable:
    """Helper decorator for distributed configuration to collect instance attribute value
    across all participating processes and apply the specified reduction operation.

    See :doc:`metrics` on how to use it.

    Args:
        attrs: attribute names of decorated class

    .. versionchanged:: 0.4.5
        - Ability to handle different reduction operations (SUM, MAX, MIN, PRODUCT).
    """

    def wrapper(func: Callable) -> Callable:
        @wraps(func)
        def another_wrapper(self: Metric, *args: Any, **kwargs: Any) -> Callable:
            if not isinstance(self, Metric):
                raise RuntimeError(
                    "Decorator sync_all_reduce should be used on ignite.metric.Metric class methods only"
                )
            ws = idist.get_world_size()
            if len(attrs) > 0 and not self._is_reduced:
                if ws > 1:
                    for attr in attrs:
                        op_kwargs = {}
                        # "name:OP" selects the reduction op; a plain "name" uses all_reduce's default
                        if ":" in attr:
                            attr, op = attr.split(":")
                            valid_ops = ["MIN", "MAX", "SUM", "PRODUCT"]
                            if op not in valid_ops:
                                raise ValueError(f"Reduction operation is not valid (expected : {valid_ops}, got: {op})")
                            op_kwargs["op"] = op
                        t = getattr(self, attr, None)
                        if t is not None:
                            # this is the call the issue proposes to make configurable
                            t = idist.all_reduce(t, **op_kwargs)
                            self._is_reduced = True
                            setattr(self, attr, t)
                else:
                    self._is_reduced = True

            return func(self, *args, **kwargs)

        return another_wrapper

    setattr(wrapper, "_decorated", True)
    return wrapper
```
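To make the control flow above easier to follow, here is a dependency-free sketch of how `reinit__is_reduced` and `sync_all_reduce` cooperate through the `_is_reduced` flag. It is a simplified model, not ignite's code: `ToyMetric`, `fake_all_reduce`, `CALLS` and the module-level `WORLD_SIZE` are hypothetical stand-ins for `Metric`, `idist.all_reduce` and `idist.get_world_size()`, op validation is omitted, and the fake reduction returns the local value unchanged while recording which attribute/op pairs it was asked to reduce.

```python
from functools import wraps
from typing import Any, Callable, List, Tuple

# Hypothetical stand-ins for idist.get_world_size() and idist.all_reduce().
WORLD_SIZE = 2
CALLS: List[Tuple[str, str]] = []  # (attr, op) pairs seen by the reducer


def fake_all_reduce(value: float, op: str = "SUM") -> float:
    # A real implementation would communicate with the other processes;
    # this placeholder simply returns the local value.
    return value


def reinit__is_reduced(func: Callable) -> Callable:
    # Simplified: after reset/update, mark the state as "not reduced yet".
    @wraps(func)
    def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
        result = func(self, *args, **kwargs)
        self._is_reduced = False
        return result

    return wrapper


def sync_all_reduce(*attrs: str) -> Callable:
    def wrapper(func: Callable) -> Callable:
        @wraps(func)
        def another_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            if len(attrs) > 0 and not self._is_reduced:
                if WORLD_SIZE > 1:
                    for spec in attrs:
                        # "name:OP" selects the op; a plain "name" defaults to SUM
                        name, _, op = spec.partition(":")
                        op = op or "SUM"
                        value = getattr(self, name, None)
                        if value is not None:
                            CALLS.append((name, op))
                            setattr(self, name, fake_all_reduce(value, op=op))
                            self._is_reduced = True
                else:
                    self._is_reduced = True
            return func(self, *args, **kwargs)

        return another_wrapper

    return wrapper


class ToyMetric:
    def __init__(self) -> None:
        self._is_reduced = False
        self.reset()

    @reinit__is_reduced
    def reset(self) -> None:
        self._sum = 0.0
        self._max = float("-inf")

    @reinit__is_reduced
    def update(self, x: float) -> None:
        self._sum += x
        self._max = max(self._max, x)

    @sync_all_reduce("_sum", "_max:MAX")
    def compute(self) -> Tuple[float, float]:
        return self._sum, self._max


m = ToyMetric()
m.update(1.0)
m.update(3.0)
print(m.compute())  # (4.0, 3.0)
m.compute()         # already reduced: no new reducer calls
print(CALLS)        # [('_sum', 'SUM'), ('_max', 'MAX')]
m.update(2.0)       # update clears _is_reduced, so the next compute reduces again
m.compute()
print(len(CALLS))   # 4
```

The point of the flag is that repeated `compute` calls do not trigger repeated collective operations; only a new `reset`/`update` makes the next `compute` reduce again.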

where we call `idist.all_reduce(t, **op_kwargs)`.
So, the issue description says:

Idea is to make configurable Metric's reduction/gathering ops. By default, we are using our code, but user can globally override those functions.

In other words, we would like to be able to call a user-defined `all_reduce` instead of `idist.all_reduce`.


A tentative API for this feature

```python
from typing import Union

import torch

import ignite.distributed as idist
from ignite.metrics import set_all_reduce_fn, reset_all_reduce_fn, get_all_reduce_fn  # proposed (tentative) API
from ignite.metrics import Accuracy


def my_all_reduce(tensor: Union[torch.Tensor, float], op: str = "SUM", **kwargs):
    # ... custom implementation
    pass


set_all_reduce_fn(my_all_reduce)
assert get_all_reduce_fn() == my_all_reduce

acc = Accuracy()
acc.update(...)
value = acc.compute()  # should call my_all_reduce

reset_all_reduce_fn()
assert get_all_reduce_fn() == idist.all_reduce
```
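A minimal, dependency-free sketch of how the tentative `set_all_reduce_fn` / `get_all_reduce_fn` / `reset_all_reduce_fn` trio could work, assuming a module-level registry. The key design choice is that the decorator resolves the function through `get_all_reduce_fn()` every time `compute` runs, instead of binding `idist.all_reduce` when the class is defined, so an override installed after a metric class was created still takes effect. `default_all_reduce` is a hypothetical placeholder for `idist.all_reduce`, `Counter` and `my_all_reduce` are illustrative only, and this `sync_all_reduce` is stripped down (no `_is_reduced` bookkeeping, no world-size check, no op validation).

```python
from functools import wraps
from typing import Any, Callable, List

AllReduceFn = Callable[..., Any]


def default_all_reduce(value: Any, op: str = "SUM", **kwargs: Any) -> Any:
    # Hypothetical placeholder for idist.all_reduce: a no-op in a single process.
    return value


_all_reduce_fn: AllReduceFn = default_all_reduce


def set_all_reduce_fn(fn: AllReduceFn) -> None:
    global _all_reduce_fn
    if not callable(fn):
        raise TypeError(f"Expected a callable, got {type(fn).__name__}")
    _all_reduce_fn = fn


def get_all_reduce_fn() -> AllReduceFn:
    return _all_reduce_fn


def reset_all_reduce_fn() -> None:
    global _all_reduce_fn
    _all_reduce_fn = default_all_reduce


def sync_all_reduce(*attrs: str) -> Callable:
    def wrapper(func: Callable) -> Callable:
        @wraps(func)
        def another_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            reduce_fn = get_all_reduce_fn()  # looked up at call time, not decoration time
            for spec in attrs:
                name, _, op = spec.partition(":")
                value = getattr(self, name, None)
                if value is not None:
                    setattr(self, name, reduce_fn(value, op=op or "SUM"))
            return func(self, *args, **kwargs)

        return another_wrapper

    return wrapper


class Counter:
    def __init__(self) -> None:
        self._total = 0

    def update(self, n: int) -> None:
        self._total += n

    @sync_all_reduce("_total")
    def compute(self) -> int:
        return self._total


calls: List[str] = []


def my_all_reduce(value: Any, op: str = "SUM", **kwargs: Any) -> Any:
    # Hypothetical custom backend: record the op and pretend two processes
    # hold identical values, so SUM doubles the local value.
    calls.append(op)
    return value * 2 if op == "SUM" else value


# Counter was defined before the override, yet the override is still used.
set_all_reduce_fn(my_all_reduce)
assert get_all_reduce_fn() is my_all_reduce

c = Counter()
c.update(3)
print(c.compute())  # 6

reset_all_reduce_fn()
assert get_all_reduce_fn() is default_all_reduce
```

Because the lookup happens on each call, this also makes the override global: every metric built on the decorator picks up whatever function is currently registered.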
