warning if current device index is lower than current local rank (#1335) #1376

Merged 3 commits on Oct 8, 2020
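The change below addresses a common multi-GPU pitfall: when a worker process never calls torch.cuda.set_device(local_rank), its current CUDA device stays at index 0, so several processes silently share cuda:0. With this PR, the computation models emit a UserWarning when they detect that mismatch instead of returning a misleading device. A minimal sketch of the setup the warning nudges users toward, assuming a native torch.distributed launch with NCCL on a single node (the worker function and init address are illustrative, not part of this PR):

    import torch
    import torch.distributed as dist

    def worker(local_rank: int, world_size: int) -> None:
        dist.init_process_group(
            "nccl", init_method="tcp://0.0.0.0:2222", world_size=world_size, rank=local_rank
        )
        # Pin this process to its own GPU right after initialization; without this,
        # torch.cuda.current_device() stays 0 on every rank, which is the condition
        # the patched device() methods now warn about.
        torch.cuda.set_device(local_rank)
        assert torch.cuda.current_device() == local_rank

The Horovod path is analogous: call torch.cuda.set_device(hvd.local_rank()) right after hvd.init(), as the updated tests below do.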
Changes from all commits
ignite/distributed/comp_models/horovod.py (8 changes: 7 additions & 1 deletion)

@@ -1,4 +1,5 @@
 import os
+import warnings
 from typing import Any, Callable, Mapping, Optional, Tuple
 
 import torch

@@ -68,7 +69,7 @@ def __init__(self, do_init: bool = False, **kwargs: Any) -> None:
 
         self._local_rank = hvd.local_rank()
 
-        if torch.cuda.is_available():
+        if do_init and torch.cuda.is_available():
             torch.cuda.set_device(self._local_rank)
 
         self._setup_attrs()

@@ -97,6 +98,11 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if torch.cuda.is_available():
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn(
+                    "Current device index is less than current local rank. "
+                    "Please, make sure to call torch.cuda.set_device(local_rank)."
+                )
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")
 
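Two things change here: set_device is now only called automatically when the model itself initializes Horovod (do_init=True), and device() warns when the current device index lags behind the local rank. A hedged sketch of the attach-to-existing-context path, assuming a GPU host and that each worker runs this after its own hvd.init(), as the tests further down do:

    import horovod.torch as hvd
    import torch

    from ignite.distributed.comp_models.horovod import _HorovodDistModel

    hvd.init()
    if torch.cuda.is_available():
        # User-owned device placement: keeps current_device() == local rank,
        # so device() below returns cuda:<local_rank> without the new warning.
        torch.cuda.set_device(hvd.local_rank())

    model = _HorovodDistModel.create_from_context()
    print(model.device())  # e.g. device(type='cuda', index=<local_rank>)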
ignite/distributed/comp_models/native.py (12 changes: 11 additions & 1 deletion)

@@ -97,7 +97,12 @@ def _init_from_context(self) -> None:
         self._setup_attrs()
 
     def _compute_nproc_per_node(self) -> int:
-        tensor = torch.tensor([self.get_local_rank() + 1]).to(self.device())
+        local_rank = self.get_local_rank()
+        device = torch.device("cpu")
+        if self.backend() == dist.Backend.NCCL:
+            # we manually set cuda device to local rank in order to avoid a hang on all_reduce
+            device = torch.device("cuda:{}".format(local_rank))
+        tensor = torch.tensor([self.get_local_rank() + 1]).to(device)
         dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
         return int(tensor.item())
 

@@ -220,6 +225,11 @@ def get_node_rank(self) -> int:
     def device(self) -> torch.device:
         if self.backend() == dist.Backend.NCCL:
             index = torch.cuda.current_device()
+            if index < self.get_local_rank():
+                warnings.warn(
+                    "Current device index is less than current local rank. "
+                    "Please, make sure to call torch.cuda.set_device(local_rank)."
+                )
             return torch.device("cuda:{}".format(index))
         return torch.device("cpu")
 
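The _compute_nproc_per_node change follows the rule the inline comment states: with NCCL, each rank must hand the collective a tensor on its own GPU, otherwise ranks that all sit on cuda:0 can make all_reduce hang. A standalone sketch of the same pattern, assuming torch.distributed is already initialized in the calling process (the helper name is illustrative):

    import torch
    import torch.distributed as dist

    def nproc_per_node(local_rank: int) -> int:
        device = torch.device("cpu")
        if dist.get_backend() == dist.Backend.NCCL:
            # NCCL collectives need a CUDA tensor on this process's own device.
            device = torch.device("cuda:{}".format(local_rank))
        tensor = torch.tensor([local_rank + 1], device=device)
        # Every rank contributes local_rank + 1; the max equals the number of
        # processes running on the largest node.
        dist.all_reduce(tensor, op=dist.ReduceOp.MAX)
        return int(tensor.item())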
tests/ignite/conftest.py (4 changes: 4 additions & 0 deletions)

@@ -231,6 +231,10 @@ def _hvd_task_with_init(func, args):
     import horovod.torch as hvd
 
     hvd.init()
+    lrank = hvd.local_rank()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(lrank)
+
     func(*args)
     hvd.shutdown()
 
tests/ignite/distributed/comp_models/test_horovod.py (46 changes: 42 additions & 4 deletions)

@@ -36,7 +36,6 @@ def _test__hvd_dist_model_create_from_backend_no_dist(backend, true_device):
     model = _HorovodDistModel.create_from_backend(backend=backend)
 
     assert hvd.rank() > -1
-    print("true_device", true_device)
     _assert_model(
         model,
         {

@@ -109,10 +108,13 @@ def _test__hvd_dist_model_create_from_context_dist(true_backend, true_device):
     assert _HorovodDistModel.create_from_context() is None
 
     hvd.init()
+    lrank = hvd.local_rank()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(lrank)
 
     true_conf = {
         "device": true_device,
-        "local_rank": hvd.local_rank(),
+        "local_rank": lrank,
         "rank": hvd.rank(),
         "world_size": hvd.size(),
         "node_index": 0,

@@ -121,6 +123,7 @@ def _test__hvd_dist_model_create_from_context_dist(true_backend, true_device):
     }
 
     model = _HorovodDistModel.create_from_context()
+    assert model.backend() == true_backend
     _assert_model(model, true_conf)
 
     hvd.shutdown()

@@ -142,18 +145,53 @@ def test__hvd_dist_model_create_no_dist_cuda(gloo_hvd_executor):
 
 @pytest.mark.distributed
 @pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU")
-def test__hvd_dist_model_create_dist(gloo_hvd_executor):
+def test__hvd_dist_model_create_dist_1(gloo_hvd_executor):
     gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cpu"), np=4)
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() > 0, reason="Skip if has GPU")
+def test__hvd_dist_model_create_dist_2(gloo_hvd_executor):
+    gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cpu"), np=4)
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test__hvd_dist_model_create_dist_cuda(gloo_hvd_executor):
+def test__hvd_dist_model_create_dist_cuda_1(gloo_hvd_executor):
     gloo_hvd_executor(_test__hvd_dist_model_create_from_backend_dist, ("horovod", "cuda"), np=torch.cuda.device_count())
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test__hvd_dist_model_create_dist_cuda_2(gloo_hvd_executor):
+    gloo_hvd_executor(_test__hvd_dist_model_create_from_context_dist, ("horovod", "cuda"), np=torch.cuda.device_count())
+
+
+def _test__hvd_dist_model_warning_index_less_localrank():
+
+    assert torch.cuda.is_available()
+    assert _HorovodDistModel.create_from_context() is None
+
+    hvd.init()
+    # We deliberately incorrectly set cuda device to 0
+    torch.cuda.set_device(0)
+
+    model = _HorovodDistModel.create_from_context()
+    assert isinstance(model, _HorovodDistModel), "{} vs _HorovodDistModel".format(type(model))
+
+    if hvd.local_rank() == 1:
+        with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+            model.device()
+
+    hvd.shutdown()
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__hvd_dist_model_warning_index_less_localrank(gloo_hvd_executor):
+    gloo_hvd_executor(_test__hvd_dist_model_warning_index_less_localrank, (), np=torch.cuda.device_count())
 
 
 def _test_dist_spawn_fn(local_rank, backend, world_size, device):
     from ignite.distributed.utils import _model
 
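The warning assertion above relies on pytest.warns with match, which treats the pattern as a regular expression searched against the warning message. A tiny self-contained sketch of that mechanism, independent of Horovod or CUDA:

    import warnings

    import pytest

    def test_warning_message_is_matched():
        with pytest.warns(UserWarning, match=r"Current device index is less than current local rank"):
            warnings.warn(
                "Current device index is less than current local rank. "
                "Please, make sure to call torch.cuda.set_device(local_rank).",
                UserWarning,
            )

The test asserts only on local rank 1 because rank 0's current device index (0) equals its local rank, so no warning is emitted there.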
tests/ignite/distributed/comp_models/test_native.py (36 changes: 34 additions & 2 deletions)

@@ -211,6 +211,8 @@ def _test__native_dist_model_create_from_context_dist(local_rank, rank, world_size, true_backend, true_device):
 
     dist.init_process_group(true_backend, "tcp://0.0.0.0:2222", world_size=world_size, rank=rank)
     dist.barrier()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(local_rank)
 
     true_conf = {
         "device": true_device,

@@ -244,22 +246,52 @@ def test__native_dist_model_create_no_dist_nccl(clean_env):
 
 
 @pytest.mark.distributed
-def test__native_dist_model_create_dist_gloo(local_rank, world_size):
+def test__native_dist_model_create_dist_gloo_1(local_rank, world_size):
     _test__native_dist_model_create_from_backend_dist(local_rank, local_rank, world_size, "gloo", "cpu")
 
 
+@pytest.mark.distributed
+def test__native_dist_model_create_dist_gloo_2(local_rank, world_size):
+    _test__native_dist_model_create_from_context_dist(local_rank, local_rank, world_size, "gloo", "cpu")
+
+
 @pytest.mark.distributed
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
-def test__native_dist_model_create_dist_nccl(local_rank, world_size):
+def test__native_dist_model_create_dist_nccl_1(local_rank, world_size):
     _test__native_dist_model_create_from_backend_dist(
         local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
     )
 
 
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
+def test__native_dist_model_create_dist_nccl_2(local_rank, world_size):
+    _test__native_dist_model_create_from_context_dist(
+        local_rank, local_rank, world_size, "nccl", "cuda:{}".format(local_rank)
+    )
+
+
+@pytest.mark.distributed
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Skip if less than 2 GPUs")
+def test__native_dist_model_warning_index_less_localrank(local_rank, world_size):
+
+    assert _NativeDistModel.create_from_context() is None
+
+    dist.init_process_group("nccl", "tcp://0.0.0.0:2222", world_size=world_size, rank=local_rank)
+    dist.barrier()
+    # We deliberately incorrectly set cuda device to 0
+    torch.cuda.set_device(0)
+
+    model = _NativeDistModel.create_from_context()
+    assert isinstance(model, _NativeDistModel), "{} vs _NativeDistModel".format(type(model))
+
+    if local_rank == 1:
+        with pytest.warns(UserWarning, match=r"Current device index is less than current local rank."):
+            model.device()
+
+    dist.destroy_process_group()
 
 
 def _test_dist_spawn_fn(local_rank, backend, world_size, device):
     from ignite.distributed.utils import _model
 
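For completeness, the misconfiguration the test above stages can also be reproduced outside pytest; a sketch using torch.multiprocessing, assuming a single node with at least two GPUs (the worker function and port are illustrative only):

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp

    def _worker(local_rank: int, world_size: int) -> None:
        dist.init_process_group(
            "nccl", init_method="tcp://0.0.0.0:2223", world_size=world_size, rank=local_rank
        )
        torch.cuda.set_device(0)  # deliberately wrong for every rank except 0
        # On ranks > 0 the current device index is now smaller than the local rank,
        # which is exactly what the patched device() methods warn about.
        print(local_rank, torch.cuda.current_device())
        dist.destroy_process_group()

    if __name__ == "__main__":
        nprocs = torch.cuda.device_count()
        mp.spawn(_worker, args=(nprocs,), nprocs=nprocs, join=True)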
tests/ignite/distributed/utils/test_horovod.py (7 changes: 7 additions & 0 deletions)

@@ -75,6 +75,9 @@ def _test_sync_as_hvd():
     from ignite.distributed.comp_models.horovod import _HorovodDistModel
 
     hvd.init()
+    lrank = hvd.local_rank()
+    if torch.cuda.is_available():
+        torch.cuda.set_device(lrank)
 
     _test_sync(_HorovodDistModel)
 

@@ -111,6 +114,10 @@ def _test_idist_methods_in_hvd_context(backend, device):
     ws = hvd.size()
     rank = hvd.rank()
     local_rank = hvd.local_rank()
+
+    if torch.cuda.is_available():
+        torch.cuda.set_device(local_rank)
+
     _test_distrib_config(local_rank, backend=backend, ws=ws, true_device=device, rank=rank)
 
     hvd.shutdown()
tests/run_cpu_tests.sh (4 changes: 2 additions & 2 deletions)

@@ -2,7 +2,7 @@
 
 set -xeu
 
-CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python$CI_PYTHON_VERSION --cov ignite --cov-report term-missing -vvv tests/
+CUDA_VISIBLE_DEVICES="" py.test --tx 4*popen//python=python${CI_PYTHON_VERSION:-3.7} --cov ignite --cov-report term-missing -vvv tests/
 
 # https://pubs.opengroup.org/onlinepubs/009695399/utilities/xcu_chap02.html#tag_02_06_02
 if [ "${SKIP_DISTRIB_TESTS:-0}" -eq "1" ]; then

@@ -11,4 +11,4 @@ fi
 
 export WORLD_SIZE=2
 
-CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python$CI_PYTHON_VERSION tests -m distributed -vvv
+CUDA_VISIBLE_DEVICES="" py.test --cov ignite --cov-append --cov-report term-missing --dist=each --tx $WORLD_SIZE*popen//python=python${CI_PYTHON_VERSION:-3.7} tests -m distributed -vvv