
Commit eaac5ea

warning if current device index is lower than current local rank

1 parent d384dc6

4 files changed: +50 −0 lines

ignite/distributed/comp_models/horovod.py

Lines changed: 3 additions & 0 deletions

@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Callable, Mapping, Optional, Tuple
 
 import torch
@@ -97,6 +98,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if torch.cuda.is_available():
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")
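
The warning fires when the GPU a worker is currently bound to does not match its local rank, which typically means torch.cuda.set_device was never called. A minimal sketch of the usual setup that avoids the warning, assuming a standard Horovod program (illustrative only, not part of this commit):

import horovod.torch as hvd
import torch

hvd.init()
# Without this call, torch.cuda.current_device() stays at 0 in every
# process, so a worker with hvd.local_rank() == 1 would trigger the new
# warning and silently share GPU 0 with worker 0.
torch.cuda.set_device(hvd.local_rank())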

ignite/distributed/comp_models/native.py

Lines changed: 2 additions & 0 deletions

@@ -221,6 +221,8 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if self.backend() == dist.Backend.NCCL:
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn("Current device index is less than current local rank.")
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")
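
The native backend has the same failure mode. A hedged sketch of the conventional remedy, assuming an env:// launch where the launcher (e.g. torchrun) exports LOCAL_RANK (the variable name follows the torchrun convention and is not something this commit adds):

import os
import torch
import torch.distributed as dist

# Pin the process to its own GPU so torch.cuda.current_device() matches
# the local rank; device() then returns cuda:<local_rank> and the new
# warning stays silent.
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl", init_method="env://")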

tests/ignite/distributed/comp_models/test_horovod.py

Lines changed: 22 additions & 0 deletions

@@ -195,3 +195,25 @@ def test__hvd_dist_model_spawn_cuda():
         nproc_per_node=num_workers_per_machine,
         use_gloo=True,
     )
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__warning_if_deviceindex_less_than_localrank(world_size):
+    import os
+    import torch.distributed as dist
+
+    os.environ["RANK"] = "1"
+    os.environ["MASTER_ADDR"] = "0.0.0.0"
+    os.environ["MASTER_PORT"] = "2222"
+
+    dist.init_process_group(backend="nccl", init_method="env://")
+
+    with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+        _HorovodDistModel.get_world_size()
+
+    dist.destroy_process_group()
+
+    del os.environ["RANK"]
+    del os.environ["MASTER_ADDR"]
+    del os.environ["MASTER_PORT"]

tests/ignite/distributed/comp_models/test_native.py

Lines changed: 23 additions & 0 deletions

@@ -299,3 +299,26 @@ def test__native_dist_model_spawn_gloo():
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test__native_dist_model_spawn_nccl():
     _test__native_dist_model_spawn("nccl", num_workers_per_machine=torch.cuda.device_count(), device="cuda")
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__warning_if_deviceindex_less_than_localrank(world_size):
+    import os
+    import torch.distributed as dist
+    import ignite.distributed as idist
+
+    os.environ["RANK"] = "1"
+    os.environ["MASTER_ADDR"] = "0.0.0.0"
+    os.environ["MASTER_PORT"] = "2222"
+
+    dist.init_process_group(backend="nccl", init_method="env://")
+
+    with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+        idist.get_world_size()
+
+    dist.destroy_process_group()
+
+    del os.environ["RANK"]
+    del os.environ["MASTER_ADDR"]
+    del os.environ["MASTER_PORT"]
