Minor optimization for idist.get_* #1196
Merged: 8 commits, Jul 13, 2020
73 changes: 45 additions & 28 deletions ignite/distributed/utils.py
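In short, the diff below drops the `_sync_model_wrapper` decorator and inlines the lazy-sync check into each `idist.get_*` helper, syncing with `temporary=True` so a serial fallback is never cached as the final model. A minimal, self-contained sketch of the two patterns (the stub class and module globals below are illustrative, not ignite's real implementations):

```python
from functools import wraps


class _SerialModel:
    """Stand-in for ignite's serial computation model (illustrative stub)."""

    def get_rank(self) -> int:
        return 0


_model = _SerialModel()
_need_to_sync = True


def sync(temporary: bool = False) -> None:
    # Stub: the real sync() inspects the registered computation models and calls
    # _set_model(model, temporary=temporary) when a distributed context exists.
    global _model
    _model = _SerialModel()


# Old pattern (removed by this PR): every accessor goes through an extra
# decorator-created wrapper call before reaching the computation model.
def _sync_model_wrapper(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if isinstance(_model, _SerialModel) and _need_to_sync:
            sync()
        return func(*args, **kwargs)

    return wrapper


@_sync_model_wrapper
def get_rank_old() -> int:
    return _model.get_rank()


# New pattern (added by this PR): the same check is inlined, and the sync is
# marked temporary so a serial model found this way keeps _need_to_sync=True.
def get_rank_new() -> int:
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_rank()


print(get_rank_old(), get_rank_new())  # 0 0 in a plain, non-distributed process
```

The saving is one Python-level wrapper call per `idist.get_*` invocation, which is what the overhead test further down measures.
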
@@ -4,7 +4,6 @@
from typing import Callable, List, Mapping, Optional, Tuple, Union

import torch
import torch.distributed as dist

from ignite.distributed.comp_models import (
    _SerialModel,
@@ -46,19 +45,13 @@
_need_to_sync = True


def _sync_model_wrapper(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if isinstance(_model, _SerialModel) and _need_to_sync:
            sync()
        return func(*args, **kwargs)

    return wrapper


def sync():
def sync(temporary=False):
    """Helper method to force this module to synchronize with current distributed context.
    This method should be used when distributed context is manually created or destroyed.

    Args:
        temporary (bool): If True, distributed model synchronization is done on every call of ``idist.get_*`` methods.
            This may have a negative performance impact.
    """
    global _model

@@ -67,13 +60,12 @@ def sync():
            continue
        model = comp_model_cls.create_from_context()
        if model is not None:
            _model = model
            _set_model(model, temporary=temporary)
            return

    _model = _SerialModel()


@_sync_model_wrapper
def device() -> torch.device:
    """Returns current device according to current distributed configuration.

@@ -84,10 +76,12 @@ def device() -> torch.device:
    Returns:
        torch.device
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.device()


@_sync_model_wrapper
def backend() -> Optional[str]:
    """Returns computation model's backend.

@@ -98,6 +92,9 @@ def backend() -> Optional[str]:
    Returns:
        str or None
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.backend()


@@ -110,7 +107,6 @@ def available_backends() -> Tuple[str]:
    return out


@_sync_model_wrapper
def model_name() -> str:
    """Returns distributed configuration name (given by ignite)

@@ -119,51 +115,66 @@ def model_name() -> str:
    - `xla-dist` for XLA distributed configuration

    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.name


@_sync_model_wrapper
def get_world_size() -> int:
    """Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_world_size()


@_sync_model_wrapper
def get_rank() -> int:
    """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_rank()


@_sync_model_wrapper
def get_local_rank() -> int:
    """Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_local_rank()


@_sync_model_wrapper
def get_nproc_per_node() -> int:
    """Returns number of processes (or tasks) per node within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_nproc_per_node()


@_sync_model_wrapper
def get_nnodes() -> int:
    """Returns number of nodes within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_nnodes()


@_sync_model_wrapper
def get_node_rank() -> int:
    """Returns node rank within current distributed configuration.
    Returns 0 if no distributed configuration.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.get_node_rank()


@@ -291,7 +302,6 @@ def train_fn(local_rank, a, b, c, d=12):
    )


@_sync_model_wrapper
def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
    """Helper method to perform all reduce operation.

@@ -303,10 +313,12 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
        torch.Tensor or number

    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.all_reduce(tensor, op)


@_sync_model_wrapper
def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
    """Helper method to perform all gather operation.

@@ -318,13 +330,18 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
        List of strings

    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    return _model.all_gather(tensor)


@_sync_model_wrapper
def barrier():
    """Helper method to synchronize all processes.
    """
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)

    _model.barrier()


@@ -356,11 +373,11 @@ def run(local_rank, *args, **kwargs):
    ComputationModel._ext_local_rank = index


def _set_model(model):
def _set_model(model, temporary=False):
    global _model, _need_to_sync
    _model = model
    _need_to_sync = True
    if not isinstance(_model, _SerialModel):
    if not isinstance(_model, _SerialModel) and not temporary:
        _need_to_sync = False


@@ -408,7 +425,7 @@ def train_fn(local_rank, a, b, c):


    """
    if not (has_xla_support or dist.is_available()):
    if not (has_xla_support or has_native_dist_support):
        # nothing to do => serial model
        # maybe warn about this
        return
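A hedged usage sketch of the behaviour this change preserves: because the serial fallback is only synced temporarily, `idist.get_*` still picks up a process group that is created manually after the first call. The single-process gloo group below (tcp init on a free localhost port) is an assumption for illustration, not part of the PR:

```python
import torch.distributed as dist
import ignite.distributed as idist

# Before any distributed init: idist falls back to the serial model.
print(idist.backend(), idist.get_world_size())  # None 1

# Create a native process group by hand (world_size=1 so this runs standalone).
dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)

# No explicit idist.sync() is required: the next idist.get_* call still sees a
# serial model, re-syncs, and picks up the freshly created gloo group.
print(idist.backend(), idist.get_world_size())  # gloo 1

dist.destroy_process_group()
```
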
76 changes: 44 additions & 32 deletions tests/ignite/distributed/utils/test_native.py
@@ -202,51 +202,63 @@ def test_idist_barrier_gloo(distributed_context_single_node_gloo):
    _test_distrib_barrier(device)


def _test_idist_methods_overhead(ok_factor):
    import time

    n = 100000
    m = 5

    t2 = 0.0
    t1 = 0.0
    for j in range(m):
        start = time.time()
        for _ in range(n):
            _ = dist.get_world_size()
            _ = dist.get_rank()
        elapsed = time.time() - start
        t2 += elapsed / n / m

        start = time.time()
        for _ in range(n):
            _ = idist.get_world_size()
            _ = idist.get_rank()
        elapsed = time.time() - start
        t1 += elapsed / n / m

    overhead_factor = t1 / t2
    assert overhead_factor < ok_factor, "{} vs {} | {} vs {}".format(overhead_factor, ok_factor, t2, t1)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(
    not torch.cuda.is_available(), reason="Do not want to run this test on Github or Travis, but CircleCI"
)
def test_idist_methods_overhead_gloo(distributed_context_single_node_gloo):
    import time
    _test_idist_methods_overhead(2.5)

    n = 100000
    start = time.time()
    for _ in range(n):
        _ = idist.get_world_size()
        _ = idist.get_rank()
    elapsed = time.time() - start
    t1 = elapsed / n
    idist.sync()
    from ignite.distributed.utils import _model
    from ignite.distributed.comp_models.native import _NativeDistModel

    start = time.time()
    for _ in range(n):
        _ = dist.get_world_size()
        _ = idist.get_rank()
    elapsed = time.time() - start
    t2 = elapsed / n
    assert isinstance(_model, _NativeDistModel)

    assert t2 * 6 > t1, "{} * 6 vs {}".format(t2, t1)
    _test_idist_methods_overhead(1.7)


@pytest.mark.distributed
@pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
def test_idist_methods_overhead_nccl(distributed_context_single_node_nccl):
    import time
    _test_idist_methods_overhead(2.5)

    n = 100000
    start = time.time()
    for _ in range(n):
        _ = idist.get_world_size()
        _ = idist.get_rank()
    elapsed = time.time() - start
    t1 = elapsed / n

    start = time.time()
    for _ in range(n):
        _ = dist.get_world_size()
        _ = idist.get_rank()
    elapsed = time.time() - start
    t2 = elapsed / n

    assert t2 * 3 > t1, "{} * 3 vs {}".format(t2, t1)
    idist.sync()
    from ignite.distributed.utils import _model
    from ignite.distributed.comp_models.native import _NativeDistModel

    assert isinstance(_model, _NativeDistModel)

    _test_idist_methods_overhead(1.7)


@pytest.mark.distributed
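For reference, a standalone micro-benchmark in the spirit of `_test_idist_methods_overhead`, runnable outside the pytest fixtures. The single-process gloo group, port choice, and `time.perf_counter` are assumptions for this sketch; the real test uses the repository's distributed fixtures and `time.time`:

```python
import time

import torch.distributed as dist
import ignite.distributed as idist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29501", rank=0, world_size=1)
idist.sync()  # make sure idist serves the native model, not the serial fallback

n = 100_000

# Baseline: raw torch.distributed calls.
start = time.perf_counter()
for _ in range(n):
    dist.get_world_size()
    dist.get_rank()
t_dist = (time.perf_counter() - start) / n

# Same calls through ignite's idist wrappers.
start = time.perf_counter()
for _ in range(n):
    idist.get_world_size()
    idist.get_rank()
t_idist = (time.perf_counter() - start) / n

# After this PR the ratio should stay well under the test's ok_factor of 1.7.
print("overhead factor: {:.2f} ({:.3e}s vs {:.3e}s per call)".format(t_idist / t_dist, t_idist, t_dist))

dist.destroy_process_group()
```
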