diff --git a/ignite/distributed/utils.py b/ignite/distributed/utils.py
index bf8dba0185e1..d147226ca0c3 100644
--- a/ignite/distributed/utils.py
+++ b/ignite/distributed/utils.py
@@ -4,7 +4,6 @@
 from typing import Callable, List, Mapping, Optional, Tuple, Union

 import torch
-import torch.distributed as dist

 from ignite.distributed.comp_models import (
     _SerialModel,
@@ -46,19 +45,13 @@
 _need_to_sync = True


-def _sync_model_wrapper(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if isinstance(_model, _SerialModel) and _need_to_sync:
-            sync()
-        return func(*args, **kwargs)
-
-    return wrapper
-
-
-def sync():
+def sync(temporary=False):
     """Helper method to force this module to synchronize with current distributed context.
     This method should be used when distributed context is manually created or destroyed.
+
+    Args:
+        temporary (bool): If True, distributed model synchronization is done on every call of ``idist.get_*`` methods.
+            This may have a negative performance impact.
     """
     global _model

@@ -67,13 +60,12 @@ def sync():
             continue
         model = comp_model_cls.create_from_context()
         if model is not None:
-            _model = model
+            _set_model(model, temporary=temporary)
             return

     _model = _SerialModel()


-@_sync_model_wrapper
 def device() -> torch.device:
     """Returns current device according to current distributed configuration.

@@ -84,10 +76,12 @@ def device() -> torch.device:
     Returns:
         torch.device
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.device()


-@_sync_model_wrapper
 def backend() -> Optional[str]:
     """Returns computation model's backend.

@@ -98,6 +92,9 @@ def backend() -> Optional[str]:
     Returns:
         str or None
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.backend()


@@ -110,7 +107,6 @@ def available_backends() -> Tuple[str]:
     return out


-@_sync_model_wrapper
 def model_name() -> str:
     """Returns distributed configuration name (given by ignite)

@@ -119,51 +115,66 @@ def model_name() -> str:
     - `xla-dist` for XLA distributed configuration

     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.name


-@_sync_model_wrapper
 def get_world_size() -> int:
     """Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_world_size()


-@_sync_model_wrapper
 def get_rank() -> int:
     """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_rank()


-@_sync_model_wrapper
 def get_local_rank() -> int:
     """Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_local_rank()


-@_sync_model_wrapper
 def get_nproc_per_node() -> int:
     """Returns number of processes (or tasks) per node within current distributed configuration.
     Returns 1 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_nproc_per_node()


-@_sync_model_wrapper
 def get_nnodes() -> int:
     """Returns number of nodes within current distributed configuration.
     Returns 1 if no distributed configuration.

     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_nnodes()


-@_sync_model_wrapper
 def get_node_rank() -> int:
     """Returns node rank within current distributed configuration. Returns 0 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_node_rank()


@@ -291,7 +302,6 @@ def train_fn(local_rank, a, b, c, d=12):
     )


-@_sync_model_wrapper
 def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
     """Helper method to perform all reduce operation.

@@ -303,10 +313,12 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[to
         torch.Tensor or number

     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.all_reduce(tensor, op)


-@_sync_model_wrapper
 def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
     """Helper method to perform all gather operation.

@@ -318,13 +330,18 @@ def all_gather(tensor: Union[torch.Tensor,
         List of strings

     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.all_gather(tensor)


-@_sync_model_wrapper
 def barrier():
     """Helper method to synchronize all processes.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     _model.barrier()


@@ -356,11 +373,11 @@ def run(local_rank, *args, **kwargs):
         ComputationModel._ext_local_rank = index


-def _set_model(model):
+def _set_model(model, temporary=False):
     global _model, _need_to_sync
     _model = model
     _need_to_sync = True
-    if not isinstance(_model, _SerialModel):
+    if not isinstance(_model, _SerialModel) and not temporary:
         _need_to_sync = False


@@ -408,7 +425,7 @@ def train_fn(local_rank, a, b, c):


     """
-    if not (has_xla_support or dist.is_available()):
+    if not (has_xla_support or has_native_dist_support):
         # nothing to do => serial model
         # maybe warn about this
         return
diff --git a/tests/ignite/distributed/utils/test_native.py b/tests/ignite/distributed/utils/test_native.py
index 50eb802027c4..ebd87778b6d7 100644
--- a/tests/ignite/distributed/utils/test_native.py
+++ b/tests/ignite/distributed/utils/test_native.py
@@ -202,51 +202,63 @@ def test_idist_barrier_gloo(distributed_context_single_node_gloo):
     _test_distrib_barrier(device)


+def _test_idist_methods_overhead(ok_factor):
+    import time
+
+    n = 100000
+    m = 5
+
+    t2 = 0.0
+    t1 = 0.0
+    for j in range(m):
+        start = time.time()
+        for _ in range(n):
+            _ = dist.get_world_size()
+            _ = dist.get_rank()
+        elapsed = time.time() - start
+        t2 += elapsed / n / m
+
+        start = time.time()
+        for _ in range(n):
+            _ = idist.get_world_size()
+            _ = idist.get_rank()
+        elapsed = time.time() - start
+        t1 += elapsed / n / m
+
+    overhead_factor = t1 / t2
+    assert overhead_factor < ok_factor, "{} vs {} | {} vs {}".format(overhead_factor, ok_factor, t2, t1)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Do not want to run this test on Github or Travis, but CircleCI"
+)
 def test_idist_methods_overhead_gloo(distributed_context_single_node_gloo):
-    import time
+    _test_idist_methods_overhead(2.5)

-    n = 100000
-    start = time.time()
-    for _ in range(n):
-        _ = idist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t1 = elapsed / n
+    idist.sync()
+    from ignite.distributed.utils import _model
+    from ignite.distributed.comp_models.native import _NativeDistModel

-    start = time.time()
-    for _ in range(n):
-        _ = dist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t2 = elapsed / n
+    assert isinstance(_model, _NativeDistModel)

-    assert t2 * 6 > t1, "{} * 6 vs {}".format(t2, t1)
+    _test_idist_methods_overhead(1.7)


 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test_idist_methods_overhead_nccl(distributed_context_single_node_nccl):
-    import time
+    _test_idist_methods_overhead(2.5)

-    n = 100000
-    start = time.time()
-    for _ in range(n):
-        _ = idist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t1 = elapsed / n
-
-    start = time.time()
-    for _ in range(n):
-        _ = dist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t2 = elapsed / n
-
-    assert t2 * 3 > t1, "{} * 3 vs {}".format(t2, t1)
+    idist.sync()
+    from ignite.distributed.utils import _model
+    from ignite.distributed.comp_models.native import _NativeDistModel
+
+    assert isinstance(_model, _NativeDistModel)
+
+    _test_idist_methods_overhead(1.7)


 @pytest.mark.distributed
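
The removed ``_sync_model_wrapper`` decorator and the new inline ``if _need_to_sync and isinstance(_model, _SerialModel): sync(temporary=True)`` guard implement the same lazy-synchronization idea, but the inline form avoids a wrapper call on every ``idist.get_*`` invocation and re-synchronizes only while the module still holds a serial model. The snippet below is a minimal, self-contained sketch of that idiom; ``_DistModel``, ``_create_from_context`` and the ``__main__`` driver are hypothetical stand-ins, not ignite's real comp-model API.

```python
# Minimal sketch of the lazy-sync pattern from the diff (hypothetical names).


class _SerialModel:
    """Fallback model used when no distributed context is active."""

    def get_world_size(self) -> int:
        return 1


class _DistModel:
    """Stand-in for a model created from an existing distributed context."""

    def __init__(self, world_size: int) -> None:
        self._world_size = world_size

    def get_world_size(self) -> int:
        return self._world_size


_model = _SerialModel()
_need_to_sync = True


def _create_from_context():
    # Placeholder for inspecting the environment (e.g. an initialized process
    # group); returns None while no distributed context exists.
    return None


def _set_model(model, temporary=False):
    global _model, _need_to_sync
    _model = model
    _need_to_sync = True
    # Once a real (non-serial) model is installed permanently, the getters
    # stop re-syncing, which is what removes the per-call overhead.
    if not isinstance(_model, _SerialModel) and not temporary:
        _need_to_sync = False


def sync(temporary=False):
    model = _create_from_context()
    _set_model(model if model is not None else _SerialModel(), temporary=temporary)


def get_world_size() -> int:
    # Same guard as in the diff: only re-sync while we still hold a serial model.
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_world_size()


if __name__ == "__main__":
    print(get_world_size())    # 1: no context found, stays serial
    _set_model(_DistModel(4))  # simulate a distributed context appearing
    print(get_world_size())    # 4: answered without another sync
```

The test file then checks the cost of this guard directly: ``_test_idist_methods_overhead`` times ``idist.get_world_size``/``idist.get_rank`` against the raw ``torch.distributed`` calls and asserts the ratio stays below ``ok_factor`` (2.5 before an explicit ``idist.sync()``, 1.7 once the ``_NativeDistModel`` is installed permanently).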