Conversation

@will-cromar (Collaborator) commented Jan 23, 2024

  • Define the base interface PjRtPlugin in C++ so its functionality can be called from C++. In general, implementations should still be written in Python, inheriting from DevicePlugin.
    • PyPjRtPlugin is the trampoline class that allows Python implementations of the virtual functions.
  • library_path and client_create_options won't be invoked until client creation time, allowing the user to change settings after importing torch_xla.
    • Fixes a bug where TpuPlugin threw an EnvironmentError upon registration.
  • It should be possible to merge the two maps of names to PJRT plugins, but I'll need to sort out some issues of lifetimes and ownership first (e.g. if the Python reference goes out of scope, the remaining C++ object behaves incorrectly).
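The lazy library_path/client_create_options behavior described above can be sketched in Python. This is a hedged illustration only: the DevicePlugin base-class shape and the FakeTpuPlugin subclass are assumptions for the example, not torch_xla's actual API.

```python
import os


class DevicePlugin:
    """Stand-in for torch_xla's Python plugin base class (hypothetical shape)."""

    def library_path(self) -> str:
        raise NotImplementedError

    def client_create_options(self) -> dict:
        return {}


class FakeTpuPlugin(DevicePlugin):
    """Hypothetical plugin illustrating lazy option resolution.

    Because library_path/client_create_options are not called at
    registration time, the user can still change settings (e.g. environment
    variables) after importing torch_xla, and registration alone cannot
    raise an EnvironmentError.
    """

    def library_path(self) -> str:
        # Resolved only when the PJRT client is actually created.
        return os.environ.get("FAKE_TPU_LIBRARY_PATH", "/lib/libtpu.so")

    def client_create_options(self) -> dict:
        return {"max_inflight_computations": 4}
```

Registering an instance of such a class is cheap and side-effect free; the potentially failing environment lookups happen later, at client creation.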

```diff
 }

-std::optional<PluginEntry> GetPjRtPlugin(const std::string& device_type) {
+std::shared_ptr<const PjRtPlugin> GetPjRtPlugin(
```
@JackCaoG (Collaborator) Jan 24, 2024

Any particular reason for this optional -> shared_ptr change?

@will-cromar (Collaborator, Author)

PjRtPlugin has to be a pointer or a reference now that we don't know the concrete type (whereas PluginEntry was just a value). My first thought was to make this an optional reference, but apparently C++ doesn't support that. An optional pointer would be unwieldy, because we'd have two layers of indirection/nullability (empty optional, and an optional holding nullptr). So I'm just returning shared_ptr here and letting nullptr represent the empty value.

```python
torch_xla._XLAC._register_pjrt_plugin(name, device_plugin)


def register_installed_plugins():
```
@vanbasten23 (Collaborator) Jan 24, 2024

pytorch/xla/plugins/cuda/README.md gives an example of how to use register_plugin. I wonder under what circumstances we would use register_installed_plugins(). (I see it's called when we import torch_xla. Does that mean that if we set XLA_REGISTER_INSTALLED_PLUGINS, users don't have to call register_plugin anymore?)

@will-cromar (Collaborator, Author)

I don't expect users to ever call either of these themselves. As long as plugin authors set the entrypoint correctly in their package, registration will happen in the background. Otherwise, plugin authors may add register_plugin to their module and run it on import (similar to how torch.distributed backend registration works).
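The background registration via package entrypoints described above can be sketched with importlib.metadata. The group name and helper function here are assumptions for illustration, not torch_xla's actual implementation:

```python
from importlib import metadata

# Hypothetical entry-point group name; the real group torch_xla scans
# may differ.
PLUGIN_GROUP = "torch_xla.plugins"


def discover_installed_plugins(group: str = PLUGIN_GROUP) -> dict:
    """Load every plugin advertised under `group` via package entry points.

    Calling this at import time lets registration happen in the
    background, so neither users nor plugin authors need to call
    register_plugin explicitly.
    """
    try:
        # Python 3.10+: entry_points() accepts a group filter.
        entries = metadata.entry_points(group=group)
    except TypeError:
        # Python 3.9 and earlier: entry_points() returns a dict of lists.
        entries = metadata.entry_points().get(group, [])
    return {ep.name: ep.load() for ep in entries}
```

A plugin package would then only need to declare an entry point under that group in its packaging metadata (e.g. in pyproject.toml) for discovery to pick it up on import.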

@jonb377 (Collaborator) left a comment

LGTM, I like the C++ to Python interface approach.

Base automatically changed from wcromar/configure-plugin to master January 25, 2024 21:13
@will-cromar will-cromar force-pushed the wcromar/cpp-plugin-interface branch from b4a01dd to 0b96292 Compare January 25, 2024 21:25
@will-cromar will-cromar marked this pull request as ready for review January 25, 2024 21:48
@engineer1109

Love the C++ interface too.

bhavya01 pushed a commit that referenced this pull request Apr 22, 2024