Commit 68c6cf7 (parent 0cfd224)

Add simple test

Refine test cases, update unit tests, add an MPI check, update the datatype map, and format the code with lintrunner.

File tree: 10 files changed (+352, -47 lines)

setup.py

Lines changed: 1 addition & 1 deletion

@@ -648,7 +648,7 @@ def run(self):
         if cmake_cache_vars["USE_XCCL"]:
             report("-- Building XCCL library")
         else:
-            report("-- Not using XCCL")
+            report("-- Not using XCCL")
         if cmake_cache_vars["USE_DISTRIBUTED"]:
             if IS_WINDOWS:
                 report("-- Building without distributed package")

test/distributed/test_c10d_common.py

Lines changed: 7 additions & 2 deletions

@@ -66,8 +66,13 @@ def gpus_for_rank(world_size):
     On a single node, all visible GPUs are evenly
     divided to subsets, each process only uses a subset.
     """
-    visible_devices = list(range(torch.cuda.device_count()))
-    gpus_per_process = torch.cuda.device_count() // world_size
+    device_count = (
+        torch.xpu.device_count()
+        if torch.xpu.is_available()
+        else torch.cuda.device_count()
+    )
+    visible_devices = list(range(device_count))
+    gpus_per_process = device_count // world_size
     gpus_for_rank = []
     for rank in range(world_size):
         gpus_for_rank.append(
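
The replacement prefers XPU devices when an XPU runtime reports as available and otherwise falls back to CUDA. Below is a minimal standalone sketch of that selection pattern, assuming a PyTorch build that exposes torch.xpu; the helper name accelerator_device_count is ours, not part of the diff.

# Standalone sketch (not part of the commit) of the device-count selection above.
# "accelerator_device_count" is a hypothetical helper name for illustration only.
import torch

def accelerator_device_count() -> int:
    # Prefer XPU devices when the XPU runtime is available,
    # otherwise fall back to CUDA.
    if torch.xpu.is_available():
        return torch.xpu.device_count()
    return torch.cuda.device_count()

# Example: with 4 visible devices and world_size = 2, gpus_for_rank() gives
# each rank 4 // 2 = 2 devices.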

test/distributed/test_c10d_xccl.py (new file)

Lines changed: 303 additions & 0 deletions

@@ -0,0 +1,303 @@
# Owner(s): ["oncall: distributed"]

import math
import os
import sys
import time
from datetime import timedelta
from unittest import mock

import torch
import torch.distributed as c10d


if not c10d.is_available() or not c10d.is_xccl_available():
    print("c10d XCCL not available, skipping tests", file=sys.stderr)
    sys.exit(0)

import test_c10d_common

import torch.distributed as dist
import torch.testing._internal.common_utils as common
from torch.testing._internal.common_distributed import (
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_xccl,
)
from torch.testing._internal.common_utils import (
    retry_on_connect_failures,
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_XPU,
    TestCase,
)


def simple_reduce_tests(rank, world_size):
    tests = [
        (
            c10d.ReduceOp.SUM,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(world_size * (world_size + 1) / 2)]),
        ),
        (
            c10d.ReduceOp.PRODUCT,
            torch.tensor([rank + 1.0]),
            torch.tensor([float(math.factorial(world_size))]),
        ),
        (
            c10d.ReduceOp.MIN,
            torch.tensor([rank + 1.0]),
            torch.tensor([1.0]),
        ),
        (
            c10d.ReduceOp.MAX,
            torch.tensor([rank + 1.0]),
            torch.tensor([world_size]),
        ),
    ]

    return tests


TEST_MULTIXPU = torch.xpu.device_count() > 1


class RendezvousEnvTest(TestCase):
    @retry_on_connect_failures
    @requires_xccl()
    @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test")
    def test_common_errors(self):
        vars = {
            "WORLD_SIZE": "1",
            "RANK": "0",
            "MASTER_ADDR": "127.0.0.1",
            "MASTER_PORT": str(common.find_free_port()),
        }

        class Env:
            def __init__(self, vars):
                self.env_patcher = mock.patch.dict(os.environ, vars, clear=True)

            def __enter__(self):
                self.env_patcher.start()

            def __exit__(self, type, value, traceback):
                self.env_patcher.stop()

        def without(d, key):
            d = d.copy()
            d.pop(key)
            return d

        def withouts(d, keys):
            d = d.copy()
            for key in keys:
                d.pop(key)
            return d

        with Env(without(vars, "WORLD_SIZE")):
            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
            with self.assertRaisesRegex(ValueError, "WORLD_SIZE expected"):
                gen = c10d.rendezvous("env://")
                next(gen)
            c10d.init_process_group(backend="xccl", world_size=1)
            self.assertEqual(c10d.get_rank(), 0)
            self.assertEqual(c10d.get_world_size(), 1)
            c10d.destroy_process_group()

        with Env(without(vars, "RANK")):
            self.assertEqual(None, os.environ.get("RANK"))
            with self.assertRaisesRegex(ValueError, "RANK expected"):
                gen = c10d.rendezvous("env://")
                next(gen)
            c10d.init_process_group(backend="xccl", rank=0)
            self.assertEqual(c10d.get_rank(), 0)
            self.assertEqual(c10d.get_world_size(), 1)
            c10d.destroy_process_group()

        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
            self.assertEqual(None, os.environ.get("RANK"))
            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
            c10d.init_process_group(backend="xccl", rank=0, world_size=1)
            self.assertEqual(c10d.get_rank(), 0)
            self.assertEqual(c10d.get_world_size(), 1)
            c10d.destroy_process_group()

        with Env(vars):
            c10d.init_process_group(backend="xccl")
            self.assertEqual(c10d.get_rank(), 0)
            self.assertEqual(c10d.get_world_size(), 1)
            c10d.destroy_process_group()

        with Env(without(vars, "MASTER_ADDR")):
            self.assertEqual(None, os.environ.get("MASTER_ADDR"))
            with self.assertRaisesRegex(ValueError, "MASTER_ADDR expected"):
                gen = c10d.rendezvous("env://")
                next(gen)

        with Env(without(vars, "MASTER_PORT")):
            self.assertEqual(None, os.environ.get("MASTER_PORT"))
            with self.assertRaisesRegex(ValueError, "MASTER_PORT expected"):
                gen = c10d.rendezvous("env://")
                next(gen)

        with Env(without(vars, "WORLD_SIZE")):
            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
            gen = c10d.rendezvous(f"env://?world_size={1}")
            _, _, size = next(gen)
            self.assertEqual(size, 1)

        with Env(without(vars, "RANK")):
            self.assertEqual(None, os.environ.get("RANK"))
            gen = c10d.rendezvous(f"env://?rank={0}")
            _, rank, _ = next(gen)
            self.assertEqual(rank, 0)

        with Env(withouts(vars, ["RANK", "WORLD_SIZE"])):
            self.assertEqual(None, os.environ.get("RANK"))
            self.assertEqual(None, os.environ.get("WORLD_SIZE"))
            gen = c10d.rendezvous(f"env://?rank={0}&world_size={1}")
            _, rank, size = next(gen)
            self.assertEqual(rank, 0)
            self.assertEqual(size, 1)


class TimeoutTest(test_c10d_common.AbstractTimeoutTest, TestCase):
    @requires_xccl()
    @retry_on_connect_failures
    @skip_but_pass_in_sandcastle_if(not TEST_XPU, "No GPUs available, skipping test")
    def test_default_store_timeout_nccl(self):
        self._test_default_store_timeout("xccl")


class ProcessGroupXCCLTest(MultiProcessTestCase):
    def _create_process_group_xccl(
        self, timeout=timedelta(seconds=600), device_id=None
    ):
        store = c10d.FileStore(self.file_name, self.world_size)
        c10d.init_process_group(
            "xccl",
            world_size=self.world_size,
            rank=self.rank,
            store=store,
            timeout=timeout,
            device_id=device_id,
        )
        pg = c10d.distributed_c10d._get_default_group()
        return pg

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 2

    @property
    def rank_to_GPU(self):
        # return rank to GPU map
        return init_multigpu_helper(self.world_size, "xccl")

    @requires_xccl()
    @skip_but_pass_in_sandcastle_if(
        torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
    )
    def test_close_multi_pg_unordered(self):
        pg = self._create_process_group_xccl()
        device = self.rank_to_GPU[self.rank][0]
        t = torch.rand(10, 10, device=device)
        # First allreduce to initialize default PG's communicator.
        pg.allreduce(t).wait()
        new_pg1 = c10d.new_group([0, 1])
        new_pg2 = c10d.new_group([0, 1])
        if self.rank == 0 or self.rank == 1:
            t1 = torch.rand(10, 10, device=device)
            t2 = torch.rand(10, 10, device=device)
            new_pg1.allreduce(t1).wait()
            new_pg2.allreduce(t2).wait()
        if self.rank == 0:
            dist.destroy_process_group(new_pg2)
            # force destruction of pg2 first
            del new_pg2
            dist.destroy_process_group(new_pg1)
            del new_pg1
        if self.rank == 1:
            c10d.destroy_process_group(new_pg1)
            # force destruction of pg1 first
            del new_pg1
            dist.destroy_process_group(new_pg2)
            del new_pg2
        dist.destroy_process_group()

    @requires_xccl()
    @skip_but_pass_in_sandcastle_if(
        torch.xpu.device_count() < 2, "XCCL test requires 2+ GPUs"
    )
    def test_file_store_check(self):
        # self.file_name is created using "delete=False"
        # e.g., self.file_name = tempfile.NamedTemporaryFile(delete=False).name
        store = dist.FileStore(self.file_name, self.world_size)
        dist.init_process_group(
            backend="xccl", rank=self.rank, world_size=self.world_size, store=store
        )
        pg = dist.distributed_c10d._get_default_group()
        self.assertEqual(pg.rank(), self.rank)
        self.assertEqual(pg.size(), self.world_size)
        # give enough time for check() to be executed multiple times
        time.sleep(2)
        dist.destroy_process_group()

    @requires_xccl()
    @skip_but_pass_in_sandcastle_if(not TEST_MULTIXPU, "XCCL test requires 2+ GPUs")
    def test_set_process_group_desc(self):
        device = torch.device(f"xpu:{self.rank}")
        pg_default = self._create_process_group_xccl(device_id=device)
        self.assertEqual(pg_default.group_desc, "default_pg")
        pg_1 = c10d.new_group([0, 1], group_desc="test_purpose")
        self.assertEqual(pg_1.group_desc, "test_purpose")
        pg_2 = c10d.new_group([0, 1])
        self.assertEqual(pg_2.group_desc, "undefined")

    def _test_allreduce_basics(self, fn):
        pg = self._create_process_group_xccl()
        device = torch.device("xpu:" + str(self.rank))
        # Single input tests
        tests = simple_reduce_tests(self.rank, self.world_size)
        for op, input, expected in tests:
            opts = c10d.AllreduceOptions()
            opts.reduceOp = op
            tensor = fn(input.to(device))
            fut = pg.allreduce([tensor], opts).get_future()
            fut.wait()
            result = fut.value()
            self.assertEqual(expected, result[0], exact_dtype=False)

        x = fn(torch.tensor([self.rank + 1.0], device=device))
        fut = pg.allreduce(x).get_future()
        fut.wait()
        result = fut.value()
        self.assertEqual(
            torch.tensor([float(self.world_size * (self.world_size + 1) / 2)]),
            result[0],
        )

    @requires_xccl()
    def test_allreduce_basics(self):
        self._test_allreduce_basics(lambda t: t.clone())


if __name__ == "__main__":
    assert (
        not torch.xpu._initialized
    ), "test_distributed must not have initialized XPU context on main process"

    run_tests()
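
For reference, the arithmetic encoded by simple_reduce_tests for the two-rank configuration used by ProcessGroupXCCLTest (world_size = 2) can be checked without any XCCL hardware; when run directly, the test file spawns one process per rank via MultiProcessTestCase. The following is a sketch of the expected values only, not of the backend.

# Expected allreduce results for world_size = 2 with per-rank inputs [1.0] and [2.0],
# matching the (op, input, expected) triples in simple_reduce_tests above.
import math

world_size = 2
inputs = [rank + 1.0 for rank in range(world_size)]  # [1.0, 2.0]

assert sum(inputs) == world_size * (world_size + 1) / 2   # SUM     -> 3.0
assert math.prod(inputs) == math.factorial(world_size)    # PRODUCT -> 2.0
assert min(inputs) == 1.0                                 # MIN     -> 1.0
assert max(inputs) == world_size                          # MAX     -> 2.0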

torch/_C/_distributed_c10d.pyi

Lines changed: 1 addition & 1 deletion

@@ -705,4 +705,4 @@ class ProcessGroupXCCL(Backend):
         store: Store,
         rank: int,
         size: int,
-    ): ...
+    ): ...
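
The stub declares a ProcessGroupXCCL constructor taking a Store, a rank, and a world size, mirroring the other backend stubs in this file. Below is a minimal sketch of direct construction based only on that signature, assuming the class is exported as torch.distributed.ProcessGroupXCCL like its NCCL counterpart; the tests above instead create the backend through init_process_group(backend="xccl").

# Hypothetical direct construction based on the stub's (store, rank, size)
# signature; exposure at torch.distributed.ProcessGroupXCCL is an assumption.
import torch.distributed as dist

store = dist.FileStore("/tmp/xccl_store", 2)  # illustrative path, world size 2
pg = dist.ProcessGroupXCCL(store, 0, 2)       # rank 0 of a 2-rank group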

torch/csrc/distributed/c10d/Ops.cpp

Lines changed: 1 addition & 1 deletion

@@ -510,7 +510,7 @@ namespace {
 #define REGISTER_C10D_OP(FUNC)  \
   REGISTER_C10D_OP1(FUNC, CPU)  \
   REGISTER_C10D_OP1(FUNC, CUDA) \
-  REGISTER_C10D_OP1(FUNC, XPU)  \
+  REGISTER_C10D_OP1(FUNC, XPU)  \
   REGISTER_C10D_OP1(FUNC, PrivateUse1)

 // Now we start to register ops with the three device keys
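
REGISTER_C10D_OP1(FUNC, XPU) registers each c10d collective for the XPU dispatch key, so collectives issued on XPU tensors route to the XPU/XCCL implementation. A minimal sketch of what this enables from Python, assuming an XCCL-enabled build, at least one XPU device per rank, and the usual MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE environment variables set per process:

# Sketch only: a collective on an XPU tensor dispatches through the
# XPU-registered c10d op once the "xccl" backend is initialized.
import torch
import torch.distributed as dist

dist.init_process_group(backend="xccl")
rank = dist.get_rank()
t = torch.ones(4, device=f"xpu:{rank}")  # one XPU device per rank assumed
dist.all_reduce(t)                       # routed via the XPU dispatch key
dist.destroy_process_group()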

0 commit comments

Comments
 (0)