Commit 0c9d1e4

vfdev-5 and sdesrozis authored
Minor optimization for idist.get_* (#1196)
* Minor optimization for idist.get_*
* Set overhead threshold to 1.9
* Keep only test_idist_methods_overhead_nccl
* Removed _sync_model_wrapper to implicitly check if we need to sync the model. This also reduces the time of idist.get_* method calls vs native calls.
* Update test_native.py
* autopep8 fix
* Update test_native.py

Co-authored-by: AutoPEP8 <>
Co-authored-by: Sylvain Desroziers <[email protected]>
1 parent b4e81fe commit 0c9d1e4
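In short, the commit drops the `_sync_model_wrapper` decorator and inlines the "do we still need to sync?" check at the top of each `idist.get_*` helper, so the common already-synchronized path is a direct call on the current computation model. The following self-contained sketch only illustrates the before/after pattern; the `_SerialModel` and `sync()` stubs stand in for ignite's real computation-model machinery and are not the actual implementations.

```python
from functools import wraps

# Stubs standing in for ignite's real computation-model machinery; only the
# control flow of the optimization is illustrated here.
class _SerialModel:
    def get_rank(self):
        return 0

_model = _SerialModel()
_need_to_sync = True

def sync(temporary=False):
    # The real sync() looks up the current distributed context and calls
    # _set_model(); with temporary=True the module is not marked as synced.
    pass

# Before: every idist.get_* call paid an extra wrapper-call indirection.
def _sync_model_wrapper(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        if isinstance(_model, _SerialModel) and _need_to_sync:
            sync()
        return func(*args, **kwargs)
    return wrapper

@_sync_model_wrapper
def get_rank_old():
    return _model.get_rank()

# After: the check is inlined, so the hot path is one flag/isinstance check
# plus a direct method call on the current computation model.
def get_rank_new():
    if _need_to_sync and isinstance(_model, _SerialModel):
        sync(temporary=True)
    return _model.get_rank()

assert get_rank_old() == get_rank_new() == 0
```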

File tree: 2 files changed (+89, -60 lines)


ignite/distributed/utils.py

Lines changed: 45 additions & 28 deletions
```diff
@@ -4,7 +4,6 @@
 from typing import Callable, List, Mapping, Optional, Tuple, Union
 
 import torch
-import torch.distributed as dist
 
 from ignite.distributed.comp_models import (
     _SerialModel,
@@ -46,19 +45,13 @@
 _need_to_sync = True
 
 
-def _sync_model_wrapper(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if isinstance(_model, _SerialModel) and _need_to_sync:
-            sync()
-        return func(*args, **kwargs)
-
-    return wrapper
-
-
-def sync():
+def sync(temporary=False):
     """Helper method to force this module to synchronize with current distributed context.
     This method should be used when distributed context is manually created or destroyed.
+
+    Args:
+        temporary (bool): If True, distributed model synchronization is done every call of ``idist.get_*`` methods.
+            This may have performance negative impact.
     """
     global _model
 
@@ -67,13 +60,12 @@ def sync():
             continue
         model = comp_model_cls.create_from_context()
         if model is not None:
-            _model = model
+            _set_model(model, temporary=temporary)
             return
 
     _model = _SerialModel()
 
 
-@_sync_model_wrapper
 def device() -> torch.device:
     """Returns current device according to current distributed configuration.
 
@@ -84,10 +76,12 @@ def device() -> torch.device:
     Returns:
         torch.device
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.device()
 
 
-@_sync_model_wrapper
 def backend() -> Optional[str]:
     """Returns computation model's backend.
 
@@ -98,6 +92,9 @@ def backend() -> Optional[str]:
     Returns:
         str or None
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.backend()
 
 
@@ -110,7 +107,6 @@ def available_backends() -> Tuple[str]:
     return out
 
 
-@_sync_model_wrapper
 def model_name() -> str:
     """Returns distributed configuration name (given by ignite)
 
@@ -119,51 +115,66 @@ def model_name() -> str:
     - `xla-dist` for XLA distributed configuration
 
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.name
 
 
-@_sync_model_wrapper
 def get_world_size() -> int:
     """Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_world_size()
 
 
-@_sync_model_wrapper
 def get_rank() -> int:
     """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_rank()
 
 
-@_sync_model_wrapper
 def get_local_rank() -> int:
     """Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_local_rank()
 
 
-@_sync_model_wrapper
 def get_nproc_per_node() -> int:
     """Returns number of processes (or tasks) per node within current distributed configuration.
     Returns 1 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_nproc_per_node()
 
 
-@_sync_model_wrapper
 def get_nnodes() -> int:
     """Returns number of nodes within current distributed configuration.
     Returns 1 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_nnodes()
 
 
-@_sync_model_wrapper
 def get_node_rank() -> int:
     """Returns node rank within current distributed configuration.
     Returns 0 if no distributed configuration.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.get_node_rank()
 
 
@@ -291,7 +302,6 @@ def train_fn(local_rank, a, b, c, d=12):
     )
 
 
-@_sync_model_wrapper
 def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
     """Helper method to perform all reduce operation.
 
@@ -303,10 +313,12 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[to
         torch.Tensor or number
 
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.all_reduce(tensor, op)
 
 
-@_sync_model_wrapper
 def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
     """Helper method to perform all gather operation.
 
@@ -318,13 +330,18 @@ def all_gather(tensor: Union[torch.Tensor,
         List of strings
 
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     return _model.all_gather(tensor)
 
 
-@_sync_model_wrapper
 def barrier():
     """Helper method to synchronize all processes.
     """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
     _model.barrier()
 
 
@@ -356,11 +373,11 @@ def run(local_rank, *args, **kwargs):
     ComputationModel._ext_local_rank = index
 
 
-def _set_model(model):
+def _set_model(model, temporary=False):
     global _model, _need_to_sync
     _model = model
     _need_to_sync = True
-    if not isinstance(_model, _SerialModel):
+    if not isinstance(_model, _SerialModel) and not temporary:
         _need_to_sync = False
 
 
@@ -408,7 +425,7 @@ def train_fn(local_rank, a, b, c):
 
 
     """
-    if not (has_xla_support or dist.is_available()):
+    if not (has_xla_support or has_native_dist_support):
         # nothing to do => serial model
         # maybe warn about this
         return
```
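As a usage illustration (not part of the commit), the `temporary=True` path means a plain, non-distributed script can call `idist.get_*` helpers without any explicit setup: the serial model answers the call, and because the synchronization is only temporary, a distributed context created later is still picked up.

```python
# Usage sketch (not part of the commit): idist.get_* helpers now lazily do a
# temporary sync when no computation model has been set up yet.
import ignite.distributed as idist

# In a plain (non-distributed) script, the serial model answers these calls.
print(idist.get_world_size())  # 1
print(idist.get_rank())        # 0
print(idist.backend())         # None

# Because the sync above was temporary, a distributed context created later
# (e.g. with torch.distributed.init_process_group) is still picked up by the
# next idist.get_* call or by an explicit idist.sync().
```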

tests/ignite/distributed/utils/test_native.py

Lines changed: 44 additions & 32 deletions
```diff
@@ -202,51 +202,63 @@ def test_idist_barrier_gloo(distributed_context_single_node_gloo):
     _test_distrib_barrier(device)
 
 
+def _test_idist_methods_overhead(ok_factor):
+    import time
+
+    n = 100000
+    m = 5
+
+    t2 = 0.0
+    t1 = 0.0
+    for j in range(m):
+        start = time.time()
+        for _ in range(n):
+            _ = dist.get_world_size()
+            _ = dist.get_rank()
+        elapsed = time.time() - start
+        t2 += elapsed / n / m
+
+        start = time.time()
+        for _ in range(n):
+            _ = idist.get_world_size()
+            _ = idist.get_rank()
+        elapsed = time.time() - start
+        t1 += elapsed / n / m
+
+    overhead_factor = t1 / t2
+    assert overhead_factor < ok_factor, "{} vs {} | {} vs {}".format(overhead_factor, ok_factor, t2, t1)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
+@pytest.mark.skipif(
+    not torch.cuda.is_available(), reason="Do not want to run this test on Github or Travis, but CircleCI"
+)
 def test_idist_methods_overhead_gloo(distributed_context_single_node_gloo):
-    import time
+    _test_idist_methods_overhead(2.5)
 
-    n = 100000
-    start = time.time()
-    for _ in range(n):
-        _ = idist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t1 = elapsed / n
+    idist.sync()
+    from ignite.distributed.utils import _model
+    from ignite.distributed.comp_models.native import _NativeDistModel
 
-    start = time.time()
-    for _ in range(n):
-        _ = dist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t2 = elapsed / n
+    assert isinstance(_model, _NativeDistModel)
 
-    assert t2 * 6 > t1, "{} * 6 vs {}".format(t2, t1)
+    _test_idist_methods_overhead(1.7)
 
 
 @pytest.mark.distributed
 @pytest.mark.skipif(not has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test_idist_methods_overhead_nccl(distributed_context_single_node_nccl):
-    import time
+    _test_idist_methods_overhead(2.5)
 
-    n = 100000
-    start = time.time()
-    for _ in range(n):
-        _ = idist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t1 = elapsed / n
-
-    start = time.time()
-    for _ in range(n):
-        _ = dist.get_world_size()
-        _ = idist.get_rank()
-    elapsed = time.time() - start
-    t2 = elapsed / n
-
-    assert t2 * 3 > t1, "{} * 3 vs {}".format(t2, t1)
+    idist.sync()
+    from ignite.distributed.utils import _model
+    from ignite.distributed.comp_models.native import _NativeDistModel
+
+    assert isinstance(_model, _NativeDistModel)
+
+    _test_idist_methods_overhead(1.7)
 
 
 @pytest.mark.distributed
```
