from typing import Callable, List, Mapping, Optional, Tuple, Union

import torch
-import torch.distributed as dist

from ignite.distributed.comp_models import (
    _SerialModel,

_need_to_sync = True


-def _sync_model_wrapper(func):
-    @wraps(func)
-    def wrapper(*args, **kwargs):
-        if isinstance(_model, _SerialModel) and _need_to_sync:
-            sync()
-        return func(*args, **kwargs)
-
-    return wrapper
-
-
-def sync():
+def sync(temporary=False):
    """Helper method to force this module to synchronize with current distributed context.

    This method should be used when distributed context is manually created or destroyed.
+
+    Args:
+        temporary (bool): If True, distributed model synchronization is done on every call of ``idist.get_*`` methods.
+            This may have a negative performance impact.
    """
    global _model

@@ -67,13 +60,12 @@ def sync():
            continue
        model = comp_model_cls.create_from_context()
        if model is not None:
-            _model = model
+            _set_model(model, temporary=temporary)
            return

    _model = _SerialModel()
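For orientation, a rough usage sketch of the manual-synchronization path that sync() serves (illustrative only, not part of this patch; it assumes a process group created by hand and the usual env:// variables such as MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE):

import torch.distributed as dist
import ignite.distributed as idist

# A process group created outside of ignite's own launch helpers.
dist.init_process_group(backend="gloo", init_method="env://")

idist.sync()  # adopt the manually created context permanently
print(idist.backend(), idist.get_rank(), idist.get_world_size())

dist.destroy_process_group()
idist.sync()  # the context is gone, so this should fall back to the serial model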
-@_sync_model_wrapper
def device() -> torch.device:
    """Returns current device according to current distributed configuration.

@@ -84,10 +76,12 @@ def device() -> torch.device:
    Returns:
        torch.device
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.device()


-@_sync_model_wrapper
def backend() -> Optional[str]:
    """Returns computation model's backend.

@@ -98,6 +92,9 @@ def backend() -> Optional[str]:
    Returns:
        str or None
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.backend()


@@ -110,7 +107,6 @@ def available_backends() -> Tuple[str]:
    return out


-@_sync_model_wrapper
def model_name() -> str:
    """Returns distributed configuration name (given by ignite)

@@ -119,51 +115,66 @@ def model_name() -> str:
    - `xla-dist` for XLA distributed configuration

    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.name
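A small illustrative sketch (not part of the patch) of how these introspection helpers are typically called; which device and backend come back depends entirely on the active configuration, e.g. a CUDA device under nccl and CPU otherwise:

import ignite.distributed as idist

print(idist.available_backends())  # backends usable in this environment
print(idist.backend())             # active backend name, or None in the serial case
print(idist.model_name())          # e.g. "serial", "native-dist" or "xla-dist"
print(idist.device())              # torch.device chosen by the current configuration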
-@_sync_model_wrapper
def get_world_size() -> int:
    """Returns world size of current distributed configuration. Returns 1 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_world_size()


-@_sync_model_wrapper
def get_rank() -> int:
    """Returns process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_rank()


-@_sync_model_wrapper
def get_local_rank() -> int:
    """Returns local process rank within current distributed configuration. Returns 0 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_local_rank()


-@_sync_model_wrapper
def get_nproc_per_node() -> int:
    """Returns number of processes (or tasks) per node within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_nproc_per_node()


-@_sync_model_wrapper
def get_nnodes() -> int:
    """Returns number of nodes within current distributed configuration.
    Returns 1 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_nnodes()


-@_sync_model_wrapper
def get_node_rank() -> int:
    """Returns node rank within current distributed configuration.
    Returns 0 if no distributed configuration.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.get_node_rank()
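As a hedged illustration of the serial fallback these inline checks give (the default values are the ones stated in the docstrings above; the snippet itself is not part of the patch):

import ignite.distributed as idist

# No distributed configuration has been created: each call does a temporary
# sync, picks the serial model and returns the documented defaults.
assert idist.get_world_size() == 1
assert idist.get_rank() == 0
assert idist.get_local_rank() == 0
assert idist.get_nproc_per_node() == 1
assert idist.get_nnodes() == 1
assert idist.get_node_rank() == 0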
@@ -291,7 +302,6 @@ def train_fn(local_rank, a, b, c, d=12):
    )


-@_sync_model_wrapper
def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[torch.Tensor, Number]:
    """Helper method to perform all reduce operation.

@@ -303,10 +313,12 @@ def all_reduce(tensor: Union[torch.Tensor, Number], op: str = "SUM") -> Union[to
        torch.Tensor or number

    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.all_reduce(tensor, op)


-@_sync_model_wrapper
def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
    """Helper method to perform all gather operation.

@@ -318,13 +330,18 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor,
        List of strings

    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    return _model.all_gather(tensor)


-@_sync_model_wrapper
def barrier():
    """Helper method to synchronize all processes.
    """
+    if _need_to_sync and isinstance(_model, _SerialModel):
+        sync(temporary=True)
+
    _model.barrier()
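A short collective-communication sketch (illustrative, not part of the patch; the function name is arbitrary, and it assumes the function runs in every process of an initialized group, for example one started with idist.spawn):

import torch
import ignite.distributed as idist

def training_step(local_rank):
    # Sum one value per process across the whole group.
    total = idist.all_reduce(torch.tensor(idist.get_rank()), op="SUM")
    # Collect one tensor per process into a single gathered result.
    ranks = idist.all_gather(torch.tensor([idist.get_rank()]))
    idist.barrier()  # wait until every process has reached this point
    if idist.get_rank() == 0:
        print(total, ranks)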
@@ -356,11 +373,11 @@ def run(local_rank, *args, **kwargs):
    ComputationModel._ext_local_rank = index


-def _set_model(model):
+def _set_model(model, temporary=False):
    global _model, _need_to_sync
    _model = model
    _need_to_sync = True
-    if not isinstance(_model, _SerialModel):
+    if not isinstance(_model, _SerialModel) and not temporary:
        _need_to_sync = False
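Reading this hunk together with the inline checks above: a temporary sync keeps _need_to_sync set, so every later idist.get_* call re-detects the context, while a regular sync() that finds a real (non-serial) model clears the flag and pins that model. A compressed, self-contained sketch of that flag logic, with the computation-model classes stubbed out purely for illustration:

class _SerialModel:  # stub standing in for ignite's serial computation model
    pass

class _NativeModel:  # stub standing in for a real distributed computation model
    pass

_model = _SerialModel()
_need_to_sync = True

def _set_model(model, temporary=False):
    global _model, _need_to_sync
    _model = model
    _need_to_sync = True
    # Only a permanently adopted non-serial model disables further re-syncing.
    if not isinstance(_model, _SerialModel) and not temporary:
        _need_to_sync = False

_set_model(_NativeModel(), temporary=True)
assert _need_to_sync       # temporary adoption: keep re-checking on later calls
_set_model(_NativeModel())
assert not _need_to_sync   # permanent adoption: the model is pinned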
@@ -408,7 +425,7 @@ def train_fn(local_rank, a, b, c):

    """
-    if not (has_xla_support or dist.is_available()):
+    if not (has_xla_support or has_native_dist_support):
        # nothing to do => serial model
        # maybe warn about this
        return