Move PJRT Python APIs out of torch_xla.experimental.* #5011
Conversation
@@ -11,14 +11,15 @@
import torch_xla.core.xla_env_vars as xenv
Should we rename these experimental files?
Yes, good catch.
# TODO(wcromar): Detect GPU device too
def device_type() -> Optional[str]:
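For context, a minimal sketch of what this function could look like, assuming the PJRT_DEVICE environment variable convention used by torch_xla (the actual diff reads it through the xenv helpers imported above, which aren't shown here):

```python
import os
from typing import Optional

def device_type() -> Optional[str]:
  """Returns the configured PJRT device type, e.g. 'TPU' or 'CPU', or None."""
  # PJRT_DEVICE is the variable torch_xla uses to select a PJRT device;
  # None signals that PJRT is not configured for this process.
  return os.environ.get('PJRT_DEVICE')
```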
I think we have similar functions in torch_xla.core.xla_model. Do we want to do some cleanup?
I want to take this chance to do some cleanup. I am always confused about which function to call for local ordinal, global ordinal, world_size, etc., and what they really mean in a pod context. If we can restructure those APIs a bit, and maybe have them in this runtime module instead, that would be nice... (random idea, might need more thinking)
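A hedged illustration of the consolidated naming being discussed; these names match the torch_xla.runtime module this PR introduces, though the exact pod semantics are what the thread is debating:

```python
import torch_xla.runtime as xr

xr.world_size()      # total number of participating processes across all hosts
xr.global_ordinal()  # this process's rank across the whole pod
xr.local_ordinal()   # this process's rank within its own host
```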
xla_model has become a kitchen sink and we put random things in it. If we can move all the runtime-related bits into this file, it is actually nicer...
Discussed this offline. We'll start to move APIs that interact directly with the runtime to a new module, and leave any modeling-related APIs in xla_model. I moved the PJRT version of rendezvous back to xla_model, and the old rendezvous will be an alias of that implementation when PJRT is enabled.
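A minimal sketch of that aliasing, assuming a using_pjrt() check; _pjrt_rendezvous and _xrt_rendezvous are hypothetical stand-ins for the two implementations, not the PR's actual code:

```python
import torch_xla.runtime as xr

def rendezvous(tag, payload=b'', replicas=[]):
  # Old entry point: forwards to the PJRT implementation when PJRT is enabled,
  # otherwise falls back to the legacy XRT path (helper names are illustrative).
  if xr.using_pjrt():
    return _pjrt_rendezvous(tag, payload, replicas)
  return _xrt_rendezvous(tag, payload, replicas)
```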
Force-pushed 8fa857f to 660a4cf
Force-pushed 93def62 to de2f117
torch_xla/runtime.py (Outdated)
return

logging.warning(
    'XRT configuration not detected. Defaulting to preview PJRT '
do we want to change preview to stable here?
Good catch. Done.
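To put the corrected message in context, a rough sketch of the fallback logic around that warning; the environment check and exact wording here are assumptions, not the PR's code:

```python
import logging
import os

def _maybe_default_to_pjrt():
  # Assumed check: if the user configured XRT explicitly, leave it alone.
  if any(name.startswith('XRT_') for name in os.environ):
    return
  logging.warning('XRT configuration not detected. Defaulting to stable PJRT '
                  'runtime. Set PJRT_DEVICE to select a device explicitly.')
  os.environ.setdefault('PJRT_DEVICE', 'CPU')
```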
LGTM, but I'd prefer to merge on Monday.
Force-pushed dc1130e to f11be83
Reorganize the experimental PJRT Python APIs:
- Created a new _internal module for APIs that are well-tested but likely to change. I moved device-specific logic here, since I expect to rework it in the near future. All of these functions are mainly used for framework development; in general, users shouldn't have to call them directly.
- Moved the public APIs that interact directly with the runtime into the new torch_xla.runtime module.
- Added a deprecation module to register deprecated aliases for all public functions that are moving out into other parts of torch_xla.

Summary of new modules:
- torch_xla.runtime
- torch_xla._internal.tpu
- torch_xla._internal.gpu
- torch_xla._internal.pjrt
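For callers, the migration implied by this summary looks roughly like this; old names should keep working through the deprecated aliases, just with a warning (using_pjrt is one example of a moved function, assumed here):

```python
# Before this PR: experimental namespace
from torch_xla.experimental import pjrt
pjrt.using_pjrt()

# After this PR: stable public module
import torch_xla.runtime as xr
xr.using_pjrt()
```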