[SPMD] Support SPMDShardToFullShape #6925
Conversation
xt = xs._mark_manual_sharding(xt)
xx = torch_xla._XLAC._spmd_shard_to_full_shape(
    xt.global_tensor,
    torch_xla._XLAC.OpSharding([], [], [], xs.ShardingType.REPLICATED),
So it uses the passed sharding type when returning back to SPMD?
Yup!
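(For context, a minimal sketch of the intended round trip, assuming `xs` is torch_xla's SPMD sharding module and that `_spmd_shard_to_full_shape` takes the shard tensor, an OpSharding, the original full shape, and the dtype, as in the test below; anything not shown in the diff is illustrative, not the PR's actual code.)

import torch_xla
import torch_xla.distributed.spmd as xs  # assumed import path for the `xs` alias

# `xt` is assumed to be an XLAShardedTensor produced by xs.mark_sharding.
# Record the full (global) shape before entering the manual-sharding region,
# since only the caller knows it afterwards.
full_shape = xt.global_tensor.shape
xt = xs._mark_manual_sharding(xt)  # SPMDFullToShardShape: switch to the per-shard view

# ... manual, per-shard computation on xt.global_tensor would go here ...

# SPMDShardToFullShape: the caller passes the full shape explicitly, and the
# OpSharding given here is the sharding the result carries once back in SPMD.
xx = torch_xla._XLAC._spmd_shard_to_full_shape(
    xt.global_tensor,
    torch_xla._XLAC.OpSharding([], [], [], xs.ShardingType.REPLICATED),
    full_shape, xt.global_tensor.dtype)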
with self.assertRaises(RuntimeError):
  x = torch_xla._XLAC._spmd_shard_to_full_shape(
      x, torch_xla._XLAC.OpSharding([], [], [], xs.ShardingType.REPLICATED),
      x.shape, x.dtype)
so we've decided to pass the original full shape as an argument, right?
At least I have decided, haha. See the PR description for the reasoning behind it.
LGTM, left some questions -- thanks @alanwaketan ❤️
Thanks @yeounoh for the quick review!
Summary:
This pull request enables SPMDShardToFullShape. The trickiest part is how to obtain the full shape; here are the options considered:
1. Bookkeep the full shape that enters SPMDFullToShardShape. Not selected, because the output could be created on the fly.
2. Construct the full shape from the local shard and the sharding spec. Not selected, because there is no way to account for padding; we can't examine the data at tracing time (a concrete illustration follows below).
3. Let users pass the full shape in. Selected, because it is the soundest path.
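(A concrete illustration of why option 2 above doesn't work, using hypothetical numbers: a 1-D tensor of length 10 sharded across 4 devices is padded so that each shard has length 3; reconstructing the full shape from the shard shape and the sharding spec would give 3 × 4 = 12 instead of the true 10, and the padding amount cannot be inspected at tracing time.)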
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_manual_sharding_e2e -k test_spmd_shard_to_full_shape