4 files changed, +59 −0 lines changed
Changed directories: ignite/distributed/comp_models, tests/ignite/distributed/comp_models
ignite/distributed/comp_models (Horovod model)

@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Callable, Mapping, Optional, Tuple

 import torch
@@ -97,6 +98,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if torch.cuda.is_available():
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")

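The guard added above fires when the process is bound to a CUDA device whose index is lower than its local rank, which usually means the worker never called torch.cuda.set_device for its own GPU. A minimal sketch of the same pattern outside Ignite, assuming the local rank is read from a LOCAL_RANK environment variable (the variable name and the check_device_binding helper are illustrative assumptions, not part of the library):

import os
import warnings

import torch


def check_device_binding(local_rank: int) -> torch.device:
    # Mirrors the guard added above: warn when the currently selected CUDA
    # device has a smaller index than this worker's local rank.
    if torch.cuda.is_available():
        index = torch.cuda.current_device()
        if index < local_rank:
            warnings.warn("Current device index is less than current local rank.")
        return torch.device("cuda:{}".format(index))
    return torch.device("cpu")


if __name__ == "__main__":
    # Hypothetical single-node setup: each worker binds itself to the GPU
    # matching its local rank first, otherwise current_device() stays at 0.
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    if torch.cuda.is_available() and local_rank < torch.cuda.device_count():
        torch.cuda.set_device(local_rank)  # avoids the warning above
    print(check_device_binding(local_rank))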
ignite/distributed/comp_models (native torch.distributed model)

@@ -221,6 +221,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if self.backend() == dist.Backend.NCCL:
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")

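With the same guard in the native (torch.distributed) model, a typical NCCL run avoids the warning by pinning each process to the GPU that matches its local rank before asking Ignite for the device. A hedged sketch using the public idist helpers; it assumes an external launcher such as torchrun has already set RANK, LOCAL_RANK, WORLD_SIZE and the master address/port in the environment:

import torch

import ignite.distributed as idist

# Runs in each worker process spawned by the launcher.
idist.initialize("nccl")

# Bind this process to the GPU matching its local rank first; otherwise
# torch.cuda.current_device() may stay at 0 and device() emits the warning.
torch.cuda.set_device(idist.get_local_rank())

device = idist.device()  # e.g. cuda:1 on the worker with local rank 1
print(device)

idist.finalize()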
tests/ignite/distributed/comp_models (Horovod model tests)

@@ -195,3 +195,30 @@ def test__hvd_dist_model_spawn_cuda():
         nproc_per_node=num_workers_per_machine,
         use_gloo=True,
     )
+
+
+"""
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+@pytest.mark.timeout(10)  # Doesn't do anything
+def test__warning_if_deviceindex_less_than_localrank():
+    import os
+    import torch.distributed as dist
+    import ignite.distributed as idist
+
+    os.environ["RANK"] = "1"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "0.0.0.0"
+    os.environ["MASTER_PORT"] = "2222"
+
+    dist.init_process_group(backend="nccl", init_method="env://")
+
+    with pytest.warns(UserWarning, match="Current device index is less than current local rank."):
+        _HorovodDistModel.get_world_size()
+
+    dist.destroy_process_group()
+
+    del os.environ["RANK"]
+    del os.environ["WORLD_SIZE"]
+    del os.environ["MASTER_ADDR"]
+    del os.environ["MASTER_PORT"]
+"""
tests/ignite/distributed/comp_models (native torch.distributed model tests)

@@ -299,3 +299,30 @@ def test__native_dist_model_spawn_gloo():
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test__native_dist_model_spawn_nccl():
     _test__native_dist_model_spawn("nccl", num_workers_per_machine=torch.cuda.device_count(), device="cuda")
+
+
+"""
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+@pytest.mark.timeout(10)  # Doesn't do anything
+def test__warning_if_deviceindex_less_than_localrank():
+    import os
+    import torch.distributed as dist
+    import ignite.distributed as idist
+
+    os.environ["RANK"] = "1"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "0.0.0.0"
+    os.environ["MASTER_PORT"] = "2222"
+
+    dist.init_process_group(backend="nccl", init_method="env://")
+
+    with pytest.warns(UserWarning, match="Current device index is less than current local rank."):
+        idist.get_world_size()
+
+    dist.destroy_process_group()
+
+    del os.environ["RANK"]
+    del os.environ["WORLD_SIZE"]
+    del os.environ["MASTER_ADDR"]
+    del os.environ["MASTER_PORT"]
+"""