Move PJRT Python APIs out of torch_xla.experimental.* #5011


Merged · 25 commits · Jun 6, 2023
13 changes: 7 additions & 6 deletions test/pjrt/test_ddp.py
@@ -7,7 +7,8 @@
 from torch.nn.parallel import DistributedDataParallel as DDP
 import torch_xla.core.xla_model as xm
 import torch_xla.experimental.pjrt_backend
-from torch_xla.experimental import pjrt, tpu
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt, tpu

 # Setup import folders.
 xla_test_folder = os.path.dirname(os.path.dirname(os.path.abspath(sys.argv[0])))
@@ -29,29 +30,29 @@ def _ddp_init(index: int = ...):
     ddp_model = DDP(model)

   def test_ddp_init(self):
-    pjrt._run_multiprocess(self._ddp_init)
+    pjrt.run_multiprocess(self._ddp_init)

-  @absltest.skipIf(pjrt.device_type() == 'GPU',
+  @absltest.skipIf(xr.device_type() == 'GPU',
                    "GPU device is not supported by pjrt.spawn_threads")
   def test_ddp_init_threaded(self):
     pjrt.spawn_threads(self._ddp_init)

   @parameterized.named_parameters(('small_net', False), ('large_net', True))
   def test_ddp_correctness(self, use_large_net: bool):
-    pjrt._run_multiprocess(
+    pjrt.run_multiprocess(
         util.ddp_correctness,
         init_method='pjrt://',
         use_large_net=use_large_net,
         debug=FLAGS.debug)

-  @absltest.skipIf(pjrt.device_type() == 'TPU' and tpu.version() < 4,
+  @absltest.skipIf(xr.device_type() == 'TPU' and tpu.version() < 4,
                    "env:// doesn't support multithreading")
   def test_ddp_correctness_env_init(self):
     with mock.patch.dict(os.environ, {
         'MASTER_ADDR': 'localhost',
         'MASTER_PORT': '12355'
     }):
-      pjrt._run_multiprocess(
+      pjrt.run_multiprocess(
           util.ddp_correctness, use_large_net=False, debug=FLAGS.debug)
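
The hunks above capture the heart of the PR: the private pjrt._run_multiprocess helper becomes the public pjrt.run_multiprocess under torch_xla._internal, while device queries such as device_type move to torch_xla.runtime. Below is a minimal sketch of the new import layout, assuming a PJRT-enabled torch_xla build; the worker function _mp_fn is illustrative, not part of the PR.

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr  # public runtime queries
from torch_xla._internal import pjrt  # internal multiprocess helpers


def _mp_fn():
  # Runs once per worker; each worker sees its own XLA device and ordinal.
  return xr.global_ordinal(), str(xm.xla_device())


if __name__ == '__main__':
  # run_multiprocess replaces the old private pjrt._run_multiprocess.
  results = pjrt.run_multiprocess(_mp_fn)
  print(results)  # e.g. {0: (0, 'xla:0'), 1: (1, 'xla:0'), ...}
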
@@ -4,7 +4,7 @@

 from absl.testing import absltest, parameterized
 import torch_xla.core.xla_env_vars as xenv
-from torch_xla.experimental import tpu
+from torch_xla._internal import tpu

 from unittest import mock

19 changes: 10 additions & 9 deletions test/pjrt/test_mesh_service.py
@@ -3,7 +3,8 @@

 import torch_xla.debug.metrics as met
 import torch_xla.core.xla_model as xm
-from torch_xla.experimental import pjrt
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt


 class PjRtMeshServiceTest(parameterized.TestCase):
@@ -14,7 +15,7 @@ def _rendezvous_static_size():
     return xm.rendezvous("test rendezvous", payload)

   def test_rendezvous_static_size(self):
-    results = pjrt._run_multiprocess(self._rendezvous_static_size)
+    results = pjrt.run_multiprocess(self._rendezvous_static_size)

     expected = sorted([b'message %d' % r for r in results])
     self.assertDictEqual(results, {r: expected for r in results})
@@ -25,25 +26,25 @@ def _rendezvous_dynamic_size():
     return xm.rendezvous("test rendezvous", payload)

   def test_rendezvous_dynamic_size(self):
-    results = pjrt._run_multiprocess(self._rendezvous_dynamic_size)
+    results = pjrt.run_multiprocess(self._rendezvous_dynamic_size)

     expected = sorted([b'message' * r for r in results])
     self.assertDictEqual(results, {r: expected for r in results})

   @staticmethod
   def _rendezvous_replica_groups():
-    replicas = list(range(pjrt.global_device_count()))
+    replicas = list(range(xr.global_device_count()))
     return xm.rendezvous("test rendezvous", b'message', replicas)

   def test_rendezvous_replica_groups(self):
-    results = pjrt._run_multiprocess(self._rendezvous_replica_groups)
+    results = pjrt.run_multiprocess(self._rendezvous_replica_groups)

     expected = [b'message'] * len(results)
     self.assertDictEqual(results, {r: expected for r in results})

   def test_rendezvous_empty_payload(self):
     test_fn = functools.partial(xm.rendezvous, 'test rendezvous', b'')
-    results = pjrt._run_multiprocess(test_fn)
+    results = pjrt.run_multiprocess(test_fn)

     expected = [b''] * len(results)
     self.assertDictEqual(results, {r: expected for r in results})
@@ -55,7 +56,7 @@ def rendezvous_default_payload_cpu_transfers():
     return met.counter_value('xla::_to_cpu')

   def test_rendezvous_default_payload_cpu_transfers(self):
-    results = pjrt._run_multiprocess(
+    results = pjrt.run_multiprocess(
         self.rendezvous_default_payload_cpu_transfers)

     # Expect one CPU transfer: the max size of all payloads
@@ -66,14 +67,14 @@ def test_rendezvous_string_payload(self):
     test_fn = functools.partial(xm.rendezvous, 'test rendezvous', "")

     with self.assertRaises(TypeError):
-      pjrt._run_multiprocess(test_fn)
+      pjrt.run_multiprocess(test_fn)

   @staticmethod
   def _mesh_reduce():
     return xm.mesh_reduce('test mesh reduce', xm.get_ordinal(), sum)

   def test_mesh_reduce(self):
-    results = pjrt._run_multiprocess(self._mesh_reduce)
+    results = pjrt.run_multiprocess(self._mesh_reduce)
     values = list(results.values())

     expected = sum(range(len(values)))
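
As these tests show, xm.rendezvous and xm.mesh_reduce keep their signatures; only the launcher and the device-count query move. A hedged sketch of the payload-exchange pattern exercised above, assuming a multi-worker PJRT configuration:

import torch_xla.core.xla_model as xm
from torch_xla._internal import pjrt


def _exchange():
  # Each worker contributes a payload; rendezvous returns all payloads.
  payload = b'message %d' % xm.get_ordinal()
  return xm.rendezvous('example', payload)


if __name__ == '__main__':
  print(pjrt.run_multiprocess(_exchange))
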
@@ -4,7 +4,7 @@
 import torch_xla.debug.metrics as met
 import torch_xla.debug.metrics_compare_utils as mcu
 from absl.testing import absltest
-from torch_xla.experimental import pjrt
+from torch_xla import runtime as xr

 EXPECTED_COMPUTATION_CLIENT_METRICS = [
     "CompileTime",
@@ -21,7 +21,7 @@
 class TestPjRtRuntimeMetrics(absltest.TestCase):

   def setUp(self):
-    pjrt.set_device_type('CPU')
+    xr.set_device_type('CPU')

   def test_metrics_report(self):
     self.assertEmpty(met.metrics_report())
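
The metrics test only needs the relocated device-type setter. A minimal sketch, assuming xr.set_device_type selects the PJRT device for the current process just as the old pjrt.set_device_type did:

import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
from torch_xla import runtime as xr

xr.set_device_type('CPU')  # replaces pjrt.set_device_type('CPU')
xm.xla_device()  # initialize the runtime; client metrics may start accumulating
print(met.metrics_report())
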
2 changes: 1 addition & 1 deletion test/pjrt/test_profiler.py
@@ -31,7 +31,7 @@ def _profile(logdir: str, port: int = 9012):
 class TestPjRtProfiler(absltest.TestCase):

   def setUp(self):
-    assert pjrt.using_pjrt()
+    assert xr.using_pjrt()

     # HACK: ensure libtpu is loaded if using TPU
     xm.xla_device()
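
The profiler test guards on the runtime check that moved in this PR. The same guard outside the test harness might look like this sketch, assuming xr.using_pjrt reports whether a PJRT device is configured:

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr

if not xr.using_pjrt():  # replaces pjrt.using_pjrt()
  raise RuntimeError('PJRT is required; set PJRT_DEVICE (e.g. PJRT_DEVICE=CPU)')

xm.xla_device()  # HACK from the test above: ensures libtpu is loaded on TPU
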
20 changes: 10 additions & 10 deletions test/pjrt/test_experimental_pjrt.py → test/pjrt/test_runtime.py
@@ -8,25 +8,25 @@
 from absl.testing import absltest, parameterized
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
-from torch_xla.experimental import pjrt
+from torch_xla import runtime as xr


 class TestExperimentalPjrt(parameterized.TestCase):

   def setUp(self):
-    pjrt.set_device_type('CPU')
+    xr.set_device_type('CPU')

   @parameterized.parameters(('CPU', 'CPU'), ('GPU', 'GPU'), ('TPU', 'TPU'),
                             ('TPU_C_API', 'TPU'), ('TPU_LEGACY', 'TPU'))
   def test_device_type(self, pjrt_device, expected):
     with mock.patch.dict(os.environ, {'PJRT_DEVICE': pjrt_device}, clear=True):
-      self.assertEqual(pjrt.device_type(), expected)
+      self.assertEqual(xr.device_type(), expected)

   def test_requires_pjrt(self):
     with mock.patch.dict(
         os.environ, {'PJRT_SELECT_DEFAULT_DEVICE': '0'}, clear=True):
       with self.assertRaises(NotImplementedError):
-        pjrt.xla_device()
+        xr.xla_device()

   def test_default_ordinals(self):
     global_ordinal = xm.get_ordinal()
@@ -37,14 +37,14 @@ def test_default_ordinals(self):

   def test_num_local_devices(self):
     self.assertLen(xm.get_xla_supported_devices(),
-                   pjrt.addressable_device_count())
+                   xr.addressable_device_count())

   def test_num_global_devices(self):
     self.assertLen(torch_xla._XLAC._xla_get_all_devices(),
-                   pjrt.global_device_count())
+                   xr.global_device_count())

   def test_world_size(self):
-    self.assertEqual(xm.xrt_world_size(), pjrt.world_size())
+    self.assertEqual(xm.xrt_world_size(), xr.world_size())

   def test_xla_device_error(self):
     with self.assertRaises(IndexError):
@@ -78,12 +78,12 @@ def test_pjrt_default_device(self, env_vars, expect_using_pjrt):

     with logs_context:
       # Configure default device
-      pjrt.using_pjrt()
+      xr.using_pjrt()

     if expect_using_pjrt:
-      self.assertIn(pjrt.device_type(), ['CPU', 'GPU', 'TPU'])
+      self.assertIn(xr.device_type(), ['CPU', 'GPU', 'TPU'])
     else:
-      self.assertIsNone(pjrt.device_type())
+      self.assertIsNone(xr.device_type())


 if __name__ == '__main__':
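
The renamed test_runtime.py checks the moved query functions one-to-one against their xla_model counterparts. A sketch of the new call sites, assuming PJRT_DEVICE is set in the environment:

import torch_xla.core.xla_model as xm
from torch_xla import runtime as xr

print(xr.device_type())  # 'CPU', 'GPU' or 'TPU'
print(xr.addressable_device_count())  # devices visible to this process
print(xr.global_device_count())  # devices across all processes
print(xr.world_size())  # should match xm.xrt_world_size()
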
@@ -11,14 +11,15 @@
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.core.xla_model as xm
 import torch_xla.distributed.xla_multiprocessing as xmp
-from torch_xla.experimental import pjrt
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt
 from absl.testing import absltest, parameterized


 class TestExperimentalPjrtGpu(parameterized.TestCase):

   def setUp(self):
-    pjrt.set_device_type('GPU')
+    xr.set_device_type('GPU')

     os.environ.update({
         xenv.PJRT_GPU_ASYNC_CLIENT: 'true',
@@ -29,27 +30,27 @@ def test_default_gpu_device(self):

     num_devices = int(os.environ[xenv.GPU_NUM_DEVICES])
     expected = {i: torch.device(f'xla:0') for i in range(num_devices)}
-    devices_per_process = pjrt._run_multiprocess(xm.xla_device)
+    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
     self.assertDictEqual(devices_per_process, expected)

   def test_multi_gpu_devices(self):
     num_devices = int(os.environ[xenv.GPU_NUM_DEVICES])
     expected = {i: torch.device(f'xla:0') for i in range(num_devices)}

-    devices_per_process = pjrt._run_multiprocess(xm.xla_device)
+    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
     self.assertDictEqual(devices_per_process, expected)

   @parameterized.named_parameters(('xla_model', xm.get_ordinal),
-                                  ('pjrt', pjrt.global_ordinal))
+                                  ('pjrt', xr.global_ordinal))
   def test_global_ordinal(self, ordinal_func):
-    results = pjrt._run_multiprocess(ordinal_func)
+    results = pjrt.run_multiprocess(ordinal_func)
     self.assertListEqual(sorted(results.values()), [0, 1, 2, 3])

   @parameterized.named_parameters(('xla_model', xm.get_local_ordinal),
-                                  ('pjrt', pjrt.local_ordinal))
+                                  ('pjrt', xr.local_ordinal))
   def test_local_ordinal(self, ordinal_func):
     # TODO(wcromar): add multiprocess tests
-    results = pjrt._run_multiprocess(ordinal_func)
+    results = pjrt.run_multiprocess(ordinal_func)
     self.assertListEqual(sorted(results.values()), [0, 1, 2, 3])

   @staticmethod
@@ -90,7 +91,7 @@ def test_multi_gpu_backwards(self):
             'device': f'xla:0'
         } for i in range(4)
     }
-    results = pjrt._run_multiprocess(self._multi_gpu_backwards)
+    results = pjrt.run_multiprocess(self._multi_gpu_backwards)

     self.assertDictEqual(results, expected)

@@ -113,15 +114,15 @@ def _broadcast(sync):
     device = xm.xla_device()
     model = nn.Linear(5, 5).to(device)
     if sync:
-      pjrt.broadcast_master_param(model)
+      xm.broadcast_master_param(model)

     xm.mark_step()
     return next(model.parameters()).detach().cpu().numpy()

   @parameterized.named_parameters(('synchronized_parameters', True),
                                   ('unsynchronized_parameters', False))
   def test_broadcast_master_param(self, sync):
-    results = pjrt._run_multiprocess(self._broadcast, sync)
+    results = pjrt.run_multiprocess(self._broadcast, sync)
     master_params = results[0]
     for ordinal, worker_params in results.items():
       if sync:
@@ -141,7 +142,7 @@ def _all_gather(pin_layout):

   @parameterized.named_parameters(('pinned', True), ('unpinned', False))
   def test_all_gather(self, pin_layout):
-    results = pjrt._run_multiprocess(self._all_gather, pin_layout)
+    results = pjrt.run_multiprocess(self._all_gather, pin_layout)

     expected = list(range(len(results)))
     for v in results.values():
@@ -167,7 +168,7 @@ def _reduce_scatter(pin_layout):

   @parameterized.named_parameters(('pinned', True), ('unpinned', False))
   def test_reduce_scatter(self, pin_layout):
-    results = pjrt._run_multiprocess(self._reduce_scatter, pin_layout)
+    results = pjrt.run_multiprocess(self._reduce_scatter, pin_layout)

     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [-ordinal])
@@ -198,7 +199,7 @@ def _all_to_all(pin_layout):

   @parameterized.named_parameters(('pinned', True), ('unpinned', False))
   def test_all_to_all(self, pin_layout):
-    results = pjrt._run_multiprocess(self._all_to_all, pin_layout)
+    results = pjrt.run_multiprocess(self._all_to_all, pin_layout)

     for ordinal, value in results.items():
       np.testing.assert_array_equal(value, [[[-ordinal] * len(results),
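
One rename in this file goes the other way: broadcast_master_param graduates from the pjrt module into xla_model itself. A hedged sketch mirroring the _broadcast helper above; the name _init_replicated_model is illustrative:

import torch.nn as nn
import torch_xla.core.xla_model as xm


def _init_replicated_model():
  device = xm.xla_device()
  model = nn.Linear(5, 5).to(device)
  # Copy the master worker's (ordinal 0) parameters to all other workers;
  # this was pjrt.broadcast_master_param before the move.
  xm.broadcast_master_param(model)
  xm.mark_step()
  return model
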
@@ -7,13 +7,14 @@
 import torch_xla.core.xla_model as xm
 import torch_xla.core.xla_env_vars as xenv
 import torch_xla.distributed.xla_multiprocessing as xmp
-from torch_xla.experimental import pjrt
+from torch_xla import runtime as xr
+from torch_xla._internal import pjrt


 class TestExperimentalPjrtMultiCpu(parameterized.TestCase):

   def setUp(self):
-    pjrt.set_device_type('CPU')
+    xr.set_device_type('CPU')

     os.environ.update({
         xenv.PJRT_CPU_ASYNC_CLIENT: 'true',
@@ -25,7 +26,7 @@ def test_default_cpu_device(self):
     os.environ.pop(xenv.PJRT_CPU_ASYNC_CLIENT, None)

     expected = {0: torch.device('xla:0')}
-    devices_per_process = pjrt._run_multiprocess(xm.xla_device)
+    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
     self.assertDictEqual(devices_per_process, expected)

   def test_multi_cpu_devices(self):
@@ -36,20 +37,20 @@ def test_multi_cpu_devices(self):
         3: torch.device('xla:3'),
     }

-    devices_per_process = pjrt._run_multiprocess(xm.xla_device)
+    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
     self.assertDictEqual(devices_per_process, expected)

   @parameterized.named_parameters(('xla_model', xm.get_ordinal),
-                                  ('pjrt', pjrt.global_ordinal))
+                                  ('pjrt', xr.global_ordinal))
   def test_global_ordinal(self, ordinal_func):
-    results = pjrt._run_multiprocess(ordinal_func)
+    results = pjrt.run_multiprocess(ordinal_func)
     self.assertListEqual(sorted(results.values()), [0, 1, 2, 3])

   @parameterized.named_parameters(('xla_model', xm.get_local_ordinal),
-                                  ('pjrt', pjrt.local_ordinal))
+                                  ('pjrt', xr.local_ordinal))
   def test_local_ordinal(self, ordinal_func):
     # TODO(wcromar): add multiprocess tests
-    results = pjrt._run_multiprocess(ordinal_func)
+    results = pjrt.run_multiprocess(ordinal_func)
     self.assertListEqual(sorted(results.values()), [0, 1, 2, 3])

   @staticmethod
@@ -91,7 +92,7 @@ def test_multi_cpu_backwards(self):
             'device': f'xla:{i}'
         } for i in range(4)
     }
-    results = pjrt._run_multiprocess(self._multi_cpu_backwards)
+    results = pjrt.run_multiprocess(self._multi_cpu_backwards)

     self.assertDictEqual(results, expected)

@@ -119,7 +120,7 @@ def _hlo_dump(tmpdir: str):

   def test_hlo_dump(self):
     tmpdir = self.create_tempdir().full_path
-    pjrt._run_multiprocess(self._hlo_dump, tmpdir)
+    pjrt.run_multiprocess(self._hlo_dump, tmpdir)

     files = os.listdir(tmpdir)
     for i in range(4):
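
Because the multi-CPU runtime is driven purely by environment configuration, the new layout is easy to reproduce outside the test harness. A hedged sketch, where CPU_NUM_DEVICES=4 is an assumption mirroring the four-ordinal assertions above:

import os

import torch_xla.core.xla_model as xm
from torch_xla._internal import pjrt

os.environ.update({
    'PJRT_DEVICE': 'CPU',  # what xr.set_device_type('CPU') sets under the hood
    'CPU_NUM_DEVICES': '4',  # assumption: expose four local CPU devices
})

if __name__ == '__main__':
  devices = pjrt.run_multiprocess(xm.xla_device)
  print(devices)  # expect four ordinals mapping to xla:0 ... xla:3
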