
Commit 8ae6caa

Fix multiple typing issues
1 parent b977b74 commit 8ae6caa

File tree: 6 files changed, +17 -23 lines changed

ignite/distributed/auto.py
ignite/distributed/comp_models/base.py
ignite/distributed/comp_models/native.py
ignite/distributed/comp_models/xla.py
ignite/distributed/launcher.py
ignite/distributed/utils.py

ignite/distributed/auto.py

Lines changed: 1 addition & 7 deletions
@@ -120,7 +120,7 @@ def auto_dataloader(dataset: Dataset, **kwargs: Any) -> Union[DataLoader, "_MpDe
 
         sampler = dataloader.sampler  # type: ignore[union-attr]
         dataloader = mp_device_loader_cls(dataloader, idist.device())
-        dataloader.sampler = sampler
+        dataloader.sampler = sampler  # type: ignore[attr-defined]
 
     return dataloader
 
@@ -320,12 +320,6 @@ def __init__(self, loader: Any, device: torch.device, **kwargs: Any) -> None:
         self._device = device
         self._parallel_loader_kwargs = kwargs
 
-    def __setattr__(self, name: str, value: Any) -> None:
-        super().__setattr__(name, value)
-
-    def __getattr__(self, name: str) -> Any:
-        super().__getattribute__(name)
-
     def __iter__(self) -> Iterator:
         parallel_loader = ParallelLoader(self._loader, [self._device], **self._parallel_loader_kwargs)
         return parallel_loader.per_device_loader(self._device)
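
Note on this file: with the pass-through __setattr__/__getattr__ removed, sampler is no longer an attribute mypy can see on the wrapper class, so the assignment site silences the error instead. A minimal sketch of the same pattern, using hypothetical classes rather than the ignite code:

class Loader:
    def __init__(self) -> None:
        self.sampler = "sequential-sampler"


class DeviceLoaderWrapper:
    # Wraps a loader but does not declare a `sampler` attribute itself.
    def __init__(self, loader: Loader) -> None:
        self._loader = loader


loader = Loader()
sampler = loader.sampler
wrapped = DeviceLoaderWrapper(loader)
# Without the ignore comment, mypy reports:
#   "DeviceLoaderWrapper" has no attribute "sampler"  [attr-defined]
wrapped.sampler = sampler  # type: ignore[attr-defined]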

ignite/distributed/comp_models/base.py

Lines changed: 3 additions & 3 deletions
@@ -221,7 +221,7 @@ class _SerialModel(ComputationModel):
     name = "serial"
     available_backends = ()
 
-    def __init__(self, _backend: Optional[str] = None, **_kwargs: Any) -> None:
+    def __init__(self, _backend: Optional[str] = None, **kwargs: Any) -> None:
         super(_SerialModel, self).__init__()
 
     def get_local_rank(self) -> int:
 
@@ -271,8 +271,8 @@ def spawn(*args: Any, **kwargs: Any) -> None:
     def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Union[torch.Tensor, Number]:
         return tensor
 
-    def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, List[str]]:
-        return cast(Union[torch.Tensor, Number], tensor)
+    def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:  # type: ignore
+        return tensor
 
     def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
         return tensor
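
Note on the all_gather hunk: the override now accepts a narrower argument type than the base class method, which mypy flags as a Liskov-substitution violation; the blanket "# type: ignore" silences that override error. A rough sketch of the situation, with hypothetical class names standing in for ComputationModel and _SerialModel:

from numbers import Number
from typing import Union

import torch


class BaseModelSketch:
    # Hypothetical stand-in for the base interface: accepts str as well.
    def all_gather(self, tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor, Number, str]:
        raise NotImplementedError


class SerialModelSketch(BaseModelSketch):
    # The parameter type is narrower than in the base class, which mypy
    # reports as an [override] (Liskov substitution) error unless ignored.
    def all_gather(self, tensor: Union[torch.Tensor, Number]) -> Union[torch.Tensor, Number]:  # type: ignore
        return tensor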

ignite/distributed/comp_models/native.py

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ def __init__(self, backend: Optional[str] = None, timeout: Optional[int] = None,
         else:
             self._init_from_context()
 
-    def _create_from_backend(self, backend: str, timeout: Optional[int] = None, **_kwargs: Any) -> None:
+    def _create_from_backend(self, backend: str, timeout: Optional[int] = None, **kwargs: Any) -> None:
         if backend == dist.Backend.NCCL and not torch.cuda.is_available():
             raise RuntimeError("Nccl backend is required but no cuda capable devices")
 
ignite/distributed/comp_models/xla.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def __init__(self, backend: Optional[str] = None, **kwargs: Any):
         else:
             self._init_from_context()
 
-    def _create_from_backend(self, backend: str, **_kwargs: Any) -> None:
+    def _create_from_backend(self, backend: str, **kwargs: Any) -> None:
         xm.rendezvous("init")
 
         self._backend = backend

ignite/distributed/launcher.py

Lines changed: 5 additions & 5 deletions
@@ -216,10 +216,10 @@ def __init__(
     @staticmethod
     def _setup_spawn_params(
         nproc_per_node: int,
-        nnodes: Optional[int],
-        node_rank: Optional[int],
-        master_addr: Optional[str],
-        master_port: Optional[int],
+        nnodes: Optional[int] = None,
+        node_rank: Optional[int] = None,
+        master_addr: Optional[str] = None,
+        master_port: Optional[int] = None,
         **spawn_kwargs: Any
     ) -> Dict:
         if nproc_per_node < 1:
 
@@ -273,7 +273,7 @@ def training(local_rank, config, **kwargs):
             **kwargs: keyword arguments of ``func``.
 
         """
-        if self._spawn_params is not None:
+        if self._spawn_params is not None and self.backend is not None:
            self.logger.info("Spawn function '{}' in {} processes".format(func, self._spawn_params["nproc_per_node"]))
            idist.spawn(self.backend, func, args=args, kwargs_dict=kwargs, **self._spawn_params)
        else:
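
Note on the second hunk: idist.spawn now takes a plain str backend (see the utils.py hunk below), so the added "self.backend is not None" condition lets mypy narrow Optional[str] to str before the call. A minimal sketch of the narrowing pattern, with hypothetical function names:

from typing import Optional


def spawn_sketch(backend: str) -> None:
    # Hypothetical stand-in for idist.spawn, which accepts only str.
    print("spawning processes for backend:", backend)


def run_sketch(backend: Optional[str], has_spawn_params: bool) -> None:
    # The "is not None" guard narrows Optional[str] to str for mypy,
    # mirroring the condition added in the diff above.
    if has_spawn_params and backend is not None:
        spawn_sketch(backend)
    else:
        print("running in the current process")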

ignite/distributed/utils.py

Lines changed: 6 additions & 6 deletions
@@ -190,7 +190,7 @@ def hostname() -> str:
 
 
 def spawn(
-    backend: Optional[str],
+    backend: str,
     fn: Callable,
     args: Tuple,
     kwargs_dict: Optional[Mapping] = None,
 
@@ -349,7 +349,7 @@ def all_gather(tensor: Union[torch.Tensor, Number, str]) -> Union[torch.Tensor,
     if _need_to_sync and isinstance(_model, _SerialModel):
         sync(temporary=True)
 
-    return _model.all_gather(tensor)
+    return _model.all_gather(tensor)  # type: ignore[arg-type]
 
 
 def broadcast(tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
 
@@ -440,7 +440,7 @@ def _set_model(model: Any, temporary: bool = False) -> None:
     _need_to_sync = False
 
 
-def _assert_backend(backend: Optional[str]) -> None:
+def _assert_backend(backend: str) -> None:
     backends = available_backends()
     if backend not in backends:
         raise ValueError("Backend should be one of '{}'".format(backends))
 
@@ -527,7 +527,7 @@ def show_config() -> None:
     logger.info("node rank: {}".format(get_node_rank()))
 
 
-def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Optional[Callable]:
+def one_rank_only(rank: int = 0, with_barrier: bool = False) -> Callable:
    """Decorator to filter handlers wrt a rank number
 
    Args:
 
@@ -549,9 +549,9 @@ def some_handler(_):
            ...
    """
 
-    def _one_rank_only(func: Callable) -> Optional[Callable]:
+    def _one_rank_only(func: Callable) -> Callable:
        @wraps(func)
-        def wrapper(*args: Any, **kwargs: Any) -> Optional[Callable]:
+        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
            ret = None
            if get_rank() == rank:
                ret = func(*args, **kwargs)
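
Note on the last two hunks: the decorator factory and the inner _one_rank_only now return a plain Callable, while the wrapper returns Optional[Any], since the wrapped handler only executes on the selected rank and yields None elsewhere. A self-contained sketch of a decorator typed this way (a hypothetical helper, not the ignite implementation):

from functools import wraps
from typing import Any, Callable, Optional

CURRENT_RANK = 0  # assumed stand-in for idist.get_rank()


def one_rank_only_sketch(rank: int = 0) -> Callable:
    # The factory and the decorator return Callable rather than
    # Optional[Callable], matching the signature change in the diff above.
    def _one_rank_only(func: Callable) -> Callable:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Optional[Any]:
            ret = None
            if CURRENT_RANK == rank:
                ret = func(*args, **kwargs)
            return ret

        return wrapper

    return _one_rank_only


@one_rank_only_sketch(rank=0)
def report() -> str:
    return "runs only on rank 0"


print(report())  # "runs only on rank 0" on rank 0, None on other ranks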
