Define PJRT plugin interface in C++ #6360
2cafa8e to 3d7f48f
}

- std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
+ std::shared_ptr<const PjRtPlugin> GetPjRtPlugin(
any particular reason for this optional -> shared_ptr change?
`PjRtPlugin` has to be a pointer or a reference now that we don't know the concrete type (whereas `PluginEntry` was just a value). My first thought was to make this an optional reference, but apparently C++ doesn't support that. An optional pointer would be unwieldy, because we'd have two layers of indirection/nullability (an empty optional, and an optional holding `nullptr`). So I'm just returning `shared_ptr` here and letting `nullptr` represent the empty value.
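For readers following along, here is a minimal sketch of the lookup pattern described in this reply. It is not the code from this PR: the registry helper, `RegisterPjRtPlugin`, `FakePlugin`, and the reduced `PjRtPlugin` interface are all hypothetical names invented for illustration. Only the `GetPjRtPlugin` return type comes from the diff above.

```cpp
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical, reduced stand-in for the abstract plugin interface.
class PjRtPlugin {
 public:
  virtual ~PjRtPlugin() = default;
  virtual std::string library_path() const = 0;
};

// Hypothetical concrete plugin, used only for this illustration.
class FakePlugin : public PjRtPlugin {
 public:
  std::string library_path() const override { return "/tmp/fake_plugin.so"; }
};

using PluginMap =
    std::unordered_map<std::string, std::shared_ptr<const PjRtPlugin>>;

// Assumed registry keyed by device type; the real storage may differ.
inline PluginMap& PluginRegistry() {
  static PluginMap registry;
  return registry;
}

// Hypothetical registration helper.
inline void RegisterPjRtPlugin(const std::string& device_type,
                               std::shared_ptr<const PjRtPlugin> plugin) {
  PluginRegistry()[device_type] = std::move(plugin);
}

// A single nullable handle: nullptr means "no plugin registered", so there is
// no second layer of emptiness as there would be with std::optional<T*>.
inline std::shared_ptr<const PjRtPlugin> GetPjRtPlugin(
    const std::string& device_type) {
  auto it = PluginRegistry().find(device_type);
  if (it == PluginRegistry().end()) {
    return nullptr;
  }
  return it->second;
}
```

A caller only needs one check, `if (auto plugin = GetPjRtPlugin(device_type)) { ... }`, and the virtual call dispatches to whatever concrete type was registered.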
  torch_xla._XLAC._register_pjrt_plugin(name, device_plugin)


def register_installed_plugins():
pytorch/xla/plugins/cuda/README.md gives an example of how to use `register_plugin`. Under what circumstances do we use `register_installed_plugins()`? I see it's called when we import `torch_xla`. Does that mean that if we set XLA_REGISTER_INSTALLED_PLUGINS, users don't have to call `register_plugin` anymore?
I don't expect users to ever call either of these themselves. As long as plugin authors set the entrypoint correctly in their package, registration will happen in the background. Otherwise, plugin authors may add `register_plugin` to their module and run it on import (similar to how `torch.distributed` backend registration works).
LGTM, I like the C++ to Python interface approach.
b4a01dd to 0b96292
Love C++ interface too.
- Define `PjRtPlugin` in C++ so the functionality can be called there. In general, implementations should still be in Python, inheriting from `DevicePlugin`.
- `PyPjRtPlugin` is the trampoline class that allows Python implementations of virtual functions.
- `library_path` and `client_create_options` won't be invoked until client creation time, allowing the user to change settings after importing `torch_xla`.
- `TpuPlugin` throws an `EnvironmentError` upon registration.
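
The deferred-invocation point can be sketched in plain C++. This is an illustration under assumptions, not the interface from this PR: `Settings`, `ExamplePlugin`, and `DescribeClient` are hypothetical, and the string-to-string option map is a simplification chosen for the example (the real option value type is not shown in this thread).

```cpp
#include <string>
#include <unordered_map>

// Reduced, assumed shape of the abstract interface.
class PjRtPlugin {
 public:
  virtual ~PjRtPlugin() = default;
  virtual std::string library_path() const = 0;
  virtual std::unordered_map<std::string, std::string>
  client_create_options() const = 0;
};

// Hypothetical setting a user might change after the plugin is registered.
struct Settings {
  std::string visible_devices = "0";
};

// Hypothetical plugin that reads settings when asked, not when constructed.
class ExamplePlugin : public PjRtPlugin {
 public:
  explicit ExamplePlugin(const Settings* settings) : settings_(settings) {}

  std::string library_path() const override {
    return "/tmp/example_plugin.so";
  }

  std::unordered_map<std::string, std::string> client_create_options()
      const override {
    return {{"visible_devices", settings_->visible_devices}};
  }

 private:
  const Settings* settings_;  // Not owned; must outlive the plugin.
};

// Stand-in for client creation: the only place the virtual methods are called.
inline std::string DescribeClient(const PjRtPlugin& plugin) {
  auto options = plugin.client_create_options();
  return plugin.library_path() +
         " visible_devices=" + options.at("visible_devices");
}
```

Because nothing is evaluated at registration, changing `settings.visible_devices` between constructing the plugin and calling `DescribeClient` is reflected in the result.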