Skip to content

Commit 16d021d

Browse files
committed
Warn if the current device index is lower than the current local rank
1 parent d384dc6 commit 16d021d

File tree

4 files changed: +59 −0 lines changed

ignite/distributed/comp_models/horovod.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import warnings
23
from typing import Callable, Mapping, Optional, Tuple
34

45
import torch
@@ -97,6 +98,8 @@ def get_node_rank(self) -> int:
9798
def device(self) -> torch.device:
9899
if torch.cuda.is_available():
99100
index = torch.cuda.current_device()
101+
if index < self.get_local_rank():
102+
warnings.warn("Current device index is less than current local rank.")
100103
return torch.device("cuda:{}".format(index))
101104
return torch.device("cpu")
102105

ignite/distributed/comp_models/native.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ def get_node_rank(self) -> int:
221221
def device(self) -> torch.device:
222222
if self.backend() == dist.Backend.NCCL:
223223
index = torch.cuda.current_device()
224+
if index < self.get_local_rank():
225+
warnings.warn("Current device index is less than current local rank.")
224226
return torch.device("cuda:{}".format(index))
225227
return torch.device("cpu")
226228

tests/ignite/distributed/comp_models/test_horovod.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,30 @@ def test__hvd_dist_model_spawn_cuda():
195195
nproc_per_node=num_workers_per_machine,
196196
use_gloo=True,
197197
)
198+
199+
200+
"""
201+
@pytest.mark.distributed
202+
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
203+
@pytest.mark.timeout(10) // Doesn't do anything
204+
def test__warning_if_deviceindex_less_than_localrank():
205+
import os
206+
import torch.distributed as dist
207+
import ignite.distributed as idist
208+
209+
os.environ["RANK"] = "1"
210+
os.environ["WORLD_SIZE"] = "1"
211+
os.environ["MASTER_ADDR"] = "0.0.0.0"
212+
os.environ["MASTER_PORT"] = "2222"
213+
214+
dist.init_process_group(backend="nccl", init_method="env://")
215+
216+
pytest.warns("Current device index is less than current local rank.", _HorovodDistModel.get_world_size)
217+
218+
dist.destroy_process_group()
219+
220+
del os.environ["RANK"]
221+
del os.environ["WORLD_SIZE"]
222+
del os.environ["MASTER_ADDR"]
223+
del os.environ["MASTER_PORT"]
224+
"""

tests/ignite/distributed/comp_models/test_native.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,3 +299,30 @@ def test__native_dist_model_spawn_gloo():
299299
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
300300
def test__native_dist_model_spawn_nccl():
301301
_test__native_dist_model_spawn("nccl", num_workers_per_machine=torch.cuda.device_count(), device="cuda")
302+
303+
304+
"""
305+
@pytest.mark.distributed
306+
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
307+
@pytest.mark.timeout(10) // Doesn't do anything
308+
def test__warning_if_deviceindex_less_than_localrank():
309+
import os
310+
import torch.distributed as dist
311+
import ignite.distributed as idist
312+
313+
os.environ["RANK"] = "1"
314+
os.environ["WORLD_SIZE"] = "1"
315+
os.environ["MASTER_ADDR"] = "0.0.0.0"
316+
os.environ["MASTER_PORT"] = "2222"
317+
318+
dist.init_process_group(backend="nccl", init_method="env://")
319+
320+
pytest.warns("Current device index is less than current local rank.", idist.get_world_size)
321+
322+
dist.destroy_process_group()
323+
324+
del os.environ["RANK"]
325+
del os.environ["WORLD_SIZE"]
326+
del os.environ["MASTER_ADDR"]
327+
del os.environ["MASTER_PORT"]
328+
"""

Comments (0)