Skip to content

Commit eb3a511

Browse files
committed
Fix all_gather gloo test
1 parent 88db78c commit eb3a511

File tree

3 files changed

+6
-3
lines changed

3 files changed

+6
-3
lines changed

ignite/distributed/comp_models/base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,9 @@ def all_reduce(self, tensor: Union[torch.Tensor, Number], op: str = "sum") -> Un
279279
def all_gather(
280280
self, tensor: Union[torch.Tensor, Number, str]
281281
) -> Union[torch.Tensor, Number, List[Number], List[str]]:
282-
return cast(Union[torch.Tensor, Number, List[Number], List[str]], tensor)
282+
if isinstance(tensor, torch.Tensor):
283+
return tensor
284+
return cast(Union[List[Number], List[str]], [tensor])
283285

284286
def broadcast(self, tensor: Union[torch.Tensor, Number, str], src: int = 0) -> Union[torch.Tensor, Number, str]:
285287
return tensor

tests/ignite/distributed/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def _test_distrib_all_reduce(device):
8383

8484
def _test_distrib_all_gather(device):
8585

86-
res = idist.all_gather(10)
86+
res = torch.tensor(idist.all_gather(10))
8787
true_res = torch.tensor([10,] * idist.get_world_size(), device=device)
8888
assert (res == true_res).all()
8989

tests/ignite/distributed/utils/test_serial.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,4 +54,5 @@ def test_idist_all_reduce_no_dist():
5454

5555

5656
def test_idist_all_gather_no_dist():
57-
assert idist.all_gather(10) == 10
57+
assert idist.all_gather(10) == [10]
58+
assert (idist.all_gather(torch.tensor(10)) == torch.tensor(10)).all()

0 commit comments

Comments
 (0)