Activate mypy in ignite.distributed #1355

Merged
merged 7 commits on Oct 6, 2020
49 changes: 27 additions & 22 deletions ignite/distributed/auto.py
@@ -1,4 +1,5 @@
import warnings
from typing import Any, Callable, Iterator, List, Optional, Union

import torch
import torch.nn as nn
@@ -16,7 +17,7 @@
__all__ = ["auto_dataloader", "auto_model", "auto_optim", "DistributedProxySampler"]


def auto_dataloader(dataset, **kwargs):
def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDeviceLoader"]:
"""Helper method to create a dataloader adapted for non-distributed and distributed configurations (supporting
all available backends from :meth:`~ignite.distributed.utils.available_backends()`).

@@ -74,7 +75,9 @@ def auto_dataloader(dataset, **kwargs):

if "batch_sampler" not in kwargs:
if kwargs.get("sampler", None) is not None:
sampler = DistributedProxySampler(kwargs["sampler"], num_replicas=world_size, rank=rank)
sampler = DistributedProxySampler(
kwargs["sampler"], num_replicas=world_size, rank=rank
) # type: Union[DistributedProxySampler, DistributedSampler, Sampler]
else:
sampler = DistributedSampler(
dataset, num_replicas=world_size, rank=rank, shuffle=kwargs.get("shuffle", True)
@@ -115,9 +118,9 @@ def auto_dataloader(dataset, **kwargs):
except ImportError:
pass

sampler = dataloader.sampler
dataloader = mp_device_loader_cls(dataloader, idist.device())
dataloader.sampler = sampler
mp_dataloader = mp_device_loader_cls(dataloader, idist.device())
mp_dataloader.sampler = dataloader.sampler # type: ignore[attr-defined]
return mp_dataloader

return dataloader
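
As context for the hunk above: the `# type:` comment on the `sampler` assignment is mypy's comment-style variable annotation, declaring a single Union that covers both branches of the if/else. A minimal standalone sketch of the same pattern (the function and names here are invented for illustration, not part of the PR):

from typing import Union

from torch.utils.data import RandomSampler, Sampler, SequentialSampler


def pick_sampler(data_source, shuffle: bool = True) -> Sampler:
    if shuffle:
        # The comment annotation fixes one declared type for both assignments below.
        sampler = RandomSampler(data_source)  # type: Union[RandomSampler, SequentialSampler]
    else:
        sampler = SequentialSampler(data_source)
    return sampler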

@@ -266,32 +269,34 @@ class DistributedProxySampler(DistributedSampler):

"""

def __init__(self, sampler: Sampler, num_replicas=None, rank=None):
def __init__(self, sampler: Sampler, num_replicas: Optional[int] = None, rank: Optional[int] = None) -> None:

if not isinstance(sampler, Sampler):
raise TypeError("Argument sampler should be instance of torch Sampler, but given: {}".format(type(sampler)))

if not hasattr(sampler, "__len__"):
raise TypeError("Argument sampler should have length")

super(DistributedProxySampler, self).__init__(sampler, num_replicas=num_replicas, rank=rank, shuffle=False)
super(DistributedProxySampler, self).__init__(
sampler, num_replicas=num_replicas, rank=rank, shuffle=False # type: ignore[arg-type]
)
self.sampler = sampler

def __iter__(self):
def __iter__(self) -> Iterator:
# deterministically shuffle based on epoch
torch.manual_seed(self.epoch)
torch.manual_seed(self.epoch) # type: ignore[attr-defined]

indices = []
while len(indices) < self.total_size:
indices = [] # type: List
while len(indices) < self.total_size: # type: ignore[attr-defined]
indices += list(self.sampler)

if len(indices) > self.total_size:
indices = indices[: self.total_size]
if len(indices) > self.total_size: # type: ignore[attr-defined]
indices = indices[: self.total_size] # type: ignore[attr-defined]

# subsample
indices = indices[self.rank : self.total_size : self.num_replicas]
if len(indices) != self.num_samples:
raise RuntimeError("{} vs {}".format(len(indices), self.num_samples))
indices = indices[self.rank : self.total_size : self.num_replicas] # type: ignore[attr-defined]
if len(indices) != self.num_samples: # type: ignore[attr-defined]
raise RuntimeError("{} vs {}".format(len(indices), self.num_samples)) # type: ignore[attr-defined]

return iter(indices)

@@ -304,22 +309,22 @@ def __iter__(self):
class _MpDeviceLoader:
# https://github.com/pytorch/xla/pull/2117
# From pytorch/xla if `torch_xla.distributed.parallel_loader.MpDeviceLoader` is not available
def __init__(self, loader, device, **kwargs):
def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
self._loader = loader
self._device = device
self._parallel_loader_kwargs = kwargs

def __iter__(self):
def __iter__(self) -> Iterator:
parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs)
return parallel_loader.per_device_loader(self._device)

def __len__(self):
def __len__(self) -> int:
return len(self._loader)

class _XLADistributedOptimizer(Optimizer):
def __init__(self, optimizer):
super(self.__class__, self).__init__(optimizer.param_groups)
def __init__(self, optimizer: Optimizer) -> None:
super(self.__class__, self).__init__(optimizer.param_groups) # type: ignore[call-arg]
self.wrapped_optimizer = optimizer

def step(self, closure=None):
def step(self, closure: Optional[Callable] = None) -> None:
xm.optimizer_step(self.wrapped_optimizer, barrier=True)
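
As a usage illustration for the `DistributedProxySampler` changed earlier in this file (a hedged sketch of a 2-process setup, not taken from the PR; the weights and sizes are arbitrary):

import torch
from torch.utils.data import WeightedRandomSampler

from ignite.distributed import DistributedProxySampler

# A base sampler over 100 items with uniform weights.
base_sampler = WeightedRandomSampler(torch.ones(100), num_samples=100)

# Wrap it so that rank 1 of 2 processes iterates over its own equal-sized shard.
sampler = DistributedProxySampler(base_sampler, num_replicas=2, rank=1)
sampler.set_epoch(0)      # __iter__ above seeds torch.manual_seed(epoch)
indices = list(sampler)   # 100 // 2 == 50 indices for this rank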
2 changes: 1 addition & 1 deletion ignite/distributed/comp_models/__init__.py
@@ -4,7 +4,7 @@
from ignite.distributed.comp_models.xla import has_xla_support


def setup_available_computation_models():
def setup_available_computation_models(): # type: ignore # inhomogeneous Tuple types are not supported
models = [
_SerialModel,
]
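
The `# type: ignore` kept above (with the note about inhomogeneous tuple types) skips annotating the registry of computation model classes. Purely as a hedged sketch with invented names, one common way to type a list of different classes is against their shared base via `Type[...]`; whether that fits the exact tuple layout returned here is not visible in this diff:

from typing import List, Type


class BaseModel:
    name = "base"


class SerialModel(BaseModel):
    name = "serial"


class NativeModel(BaseModel):
    name = "native"


def setup_models() -> List[Type[BaseModel]]:
    # Annotating against the common base keeps the list homogeneous for mypy.
    return [SerialModel, NativeModel]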
43 changes: 24 additions & 19 deletions ignite/distributed/comp_models/base.py
@@ -1,7 +1,7 @@
import warnings
from abc import ABCMeta, abstractmethod
from numbers import Number
from typing import Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Union, cast

import torch

@@ -13,15 +13,15 @@ class ComputationModel(metaclass=ABCMeta):
"""

# this is an additional local rank storage used when idist is setup from existing native torch dist context
_ext_local_rank = None
_ext_local_rank = None # type: Optional[int]

def __init__(self):
self._backend = None
self._nproc_per_node = None
self._nnodes = None
self._node = None

def _setup_attrs(self):
def _setup_attrs(self) -> None:
if self._nproc_per_node is None:
self._nproc_per_node = self._compute_nproc_per_node() if self.get_world_size() > 1 else 1
if self._nnodes is None:
@@ -66,7 +66,7 @@ def backend(self) -> Optional[str]:
pass

@abstractmethod
def finalize(self):
def finalize(self) -> None:
pass

@staticmethod
@@ -76,15 +76,15 @@ def create_from_context() -> Optional["ComputationModel"]:

@staticmethod
@abstractmethod
def create_from_backend(backend: str, **kwargs) -> "ComputationModel":
def create_from_backend(backend: str, **kwargs: Any) -> "ComputationModel":
pass

@staticmethod
@abstractmethod
def spawn(*args, **kwargs):
def spawn(*args: Any, **kwargs: Any) -> None:
pass

_collective_op_dtype = None
_collective_op_dtype = None # type: Any

@staticmethod
def _encode_str(x: str, device: torch.device) -> torch.Tensor:
@@ -107,7 +107,9 @@ def _decode_str(xs: torch.Tensor) -> List[str]:
out = [bytearray(x[: x[-1]].tolist()).decode("utf-8") for x in xs]
return out

def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args, **kwargs) -> torch.Tensor:
def _apply_op(
self, tensor: torch.Tensor, device: torch.device, fn: Callable, *args: Any, **kwargs: Any
) -> torch.Tensor:
out_dtype = None
tensor_device = None

@@ -133,7 +135,7 @@ def _apply_op(self, tensor: torch.Tensor, device: torch.device, fn: Callable, *a
return tensor

def _collective_op(
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args, **kwargs
self, tensor: Union[torch.Tensor, Number, str], fn: Callable, *args: Any, **kwargs: Any
) -> Union[torch.Tensor, Number, List[str]]:
tensor_to_number = tensor_to_str = False
device = self.device()
@@ -147,7 +149,7 @@ def _collective_op(
tensor = self._apply_op(tensor, device, fn, *args, **kwargs)

if tensor_to_number and tensor.numel() == 1:
return tensor.item()
return cast(Number, tensor.item())
elif tensor_to_str:
return self._decode_str(tensor)
return tensor
@@ -156,7 +158,7 @@ def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Un
if not isinstance(tensor, (torch.Tensor, Number)):
raise TypeError("Unhandled input type {}".format(type(tensor)))

return self._collective_op(tensor, self._do_all_reduce, op)
return cast(Union[torch.Tensor, Number], self._collective_op(tensor, self._do_all_reduce, op))

def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
if not isinstance(tensor, (torch.Tensor, Number, str)):
@@ -189,7 +191,7 @@ def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> U
tensor = self._apply_op(tensor, device, self._do_broadcast, src)

if tensor_to_number:
return tensor.item()
return cast(Number, tensor.item())
if tensor_to_str:
list_str = self._decode_str(tensor)
return list_str[0]
@@ -208,7 +210,7 @@ def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass

@abstractmethod
def barrier(self):
def barrier(self) -> None:
pass


@@ -219,6 +221,9 @@ class _SerialModel(ComputationModel):
name = "serial"
available_backends = ()

def __init__(self, _backend: Optional[str] = None, **kwargs: Any) -> None:
super(_SerialModel, self).__init__()

def get_local_rank(self) -> int:
return 0

@@ -242,10 +247,10 @@ def device(self) -> torch.device:
return torch.device("cuda")
return torch.device("cpu")

def backend(self) -> None:
def backend(self) -> Optional[str]:
return None

def finalize(self):
def finalize(self) -> None:
pass

def _compute_nproc_per_node(self) -> int:
@@ -256,17 +261,17 @@ def create_from_context() -> "_SerialModel":
return _SerialModel()

@staticmethod
def create_from_backend(backend: Optional[str] = None, **kwargs) -> "_SerialModel":
def create_from_backend(backend: Optional[str] = None, **kwargs: Any) -> "_SerialModel":
return _SerialModel()

@staticmethod
def spawn(*args, **kwargs):
def spawn(*args: Any, **kwargs: Any) -> None:
raise NotImplementedError("Serial computation model does not implement spawn method")

def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
return tensor

def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:
def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]: # type: ignore
return tensor

def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
@@ -281,5 +286,5 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
pass

def barrier(self):
def barrier(self) -> None:
pass
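
Several hunks above wrap `tensor.item()` in `cast(Number, ...)`. `typing.cast` is a static-analysis hint only; it performs no runtime conversion or check. A minimal standalone sketch of the pattern (function name invented, not from the file):

from numbers import Number
from typing import Union, cast

import torch


def maybe_to_number(x: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:
    if isinstance(x, torch.Tensor) and x.numel() == 1:
        # item() returns a plain int/float; mypy does not know those satisfy the
        # numbers.Number ABC (registration happens at runtime), so cast() asserts it.
        return cast(Number, x.item())
    return x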
24 changes: 12 additions & 12 deletions ignite/distributed/comp_models/horovod.py
@@ -1,5 +1,5 @@
import os
from typing import Callable, Mapping, Optional, Tuple
from typing import Any, Callable, Mapping, Optional, Tuple

import torch

@@ -33,7 +33,7 @@ class _HorovodDistModel(ComputationModel):
available_backends = (HOROVOD,)

@staticmethod
def _get_hvd_rank():
def _get_hvd_rank() -> int:
try:
rank = hvd.rank()
except ValueError as e:
@@ -48,7 +48,7 @@ def create_from_context() -> Optional["_HorovodDistModel"]:
return _HorovodDistModel()

@staticmethod
def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel":
def create_from_backend(backend: str, **kwargs: Any) -> "_HorovodDistModel":
if backend not in _HorovodDistModel.available_backends:
raise ValueError("Backend should be one of '{}'".format(_HorovodDistModel.available_backends))

@@ -57,7 +57,7 @@ def create_from_backend(backend: str, **kwargs) -> "_HorovodDistModel":
raise RuntimeError("Can not re-initialize Horovod if it is already initialized")
return _HorovodDistModel(do_init=True, **kwargs)

def __init__(self, do_init=False, **kwargs):
def __init__(self, do_init: bool = False, **kwargs: Any) -> None:
"""This is a private method. Please, use `create_from_backend` or `create_from_context`
"""
super(_HorovodDistModel, self).__init__()
@@ -73,7 +73,7 @@ def __init__(self, do_init=False, **kwargs):

self._setup_attrs()

def _compute_nproc_per_node(self):
def _compute_nproc_per_node(self) -> int:
return hvd.local_size()

def get_local_rank(self) -> int:
@@ -103,11 +103,11 @@ def device(self) -> torch.device:
def backend(self) -> str:
return self._backend

def finalize(self):
def finalize(self) -> None:
hvd.shutdown()

@staticmethod
def _dist_worker_task_fn(backend, fn, args, kwargs_dict):
def _dist_worker_task_fn(backend: str, fn: Callable, args: Tuple, kwargs_dict: Mapping) -> None:
from ignite.distributed.utils import _set_model, finalize

model = _HorovodDistModel.create_from_backend(backend)
@@ -116,15 +116,15 @@ def _dist_worker_task_fn(backend, fn, args, kwargs_dict):
finalize()

@staticmethod
def spawn(
def spawn( # type: ignore[override]
fn: Callable,
args: Tuple,
kwargs_dict: Optional[Mapping] = None,
nproc_per_node: int = 1,
hosts=None,
hosts: Optional[str] = None,
backend: str = HOROVOD,
**kwargs
):
**kwargs: Any
) -> None:
c1 = "nnodes" in kwargs and kwargs["nnodes"] > 1
c2 = "node_rank" in kwargs and kwargs["node_rank"] > 0
if c1 or c2:
@@ -166,7 +166,7 @@ def _do_all_gather(self, tensor: torch.Tensor) -> torch.Tensor:
def _do_broadcast(self, tensor: torch.Tensor, src: int) -> torch.Tensor:
return hvd.broadcast(tensor, root_rank=src)

def barrier(self):
def barrier(self) -> None:
# https://github.com/horovod/horovod/issues/159#issuecomment-424834603
# hvd.allreduce(torch.tensor(0, device=self.device()), name="barrier")
hvd.allreduce(torch.tensor(0, device="cpu"), name="barrier")
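
The `# type: ignore[override]` added to `spawn` above silences mypy's Liskov check: the concrete signature (explicit `fn`, `args`, `hosts`, ...) is narrower than the abstract `spawn(*args, **kwargs)` it overrides. A minimal standalone sketch of that situation, with invented class names:

from typing import Any, Callable, Tuple


class Base:
    @staticmethod
    def spawn(*args: Any, **kwargs: Any) -> None:
        raise NotImplementedError


class Concrete(Base):
    @staticmethod
    def spawn(  # type: ignore[override]  # narrower signature than Base.spawn
        fn: Callable, args: Tuple = (), nproc_per_node: int = 1, **kwargs: Any
    ) -> None:
        # Run fn sequentially here just to keep the sketch self-contained.
        for rank in range(nproc_per_node):
            fn(rank, *args)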