[SPMD] Support SPMDFullToShardShape #6922
Conversation
  return xtensors;
}

bool IsIr(const at::Tensor& tensor) {
nit: IsNonDeviceDataIR might be a clearer name.
# It looks like XLA doesn't like only having manual sharding in the HLO.
# It needs to be paired with SPMDFullToShardShape/SPMDShardToFullShape.
# The following exception cannot be caught somehow.
# xx.cpu()
Do you intend to keep this xx.cpu?
Yeah, it's more of a note that this won't work... I was trying to use self.assertRaises, but that doesn't capture the exception. I've noticed this before too: when libtpu crashes, it's hard to catch it at the Python level. Not sure why. Maybe you have some better ideas?
Oh, I think I ran into a similar issue before. The way I handled it was ugly, through
xla/test/spmd/test_dynamo_spmd.py
Lines 172 to 181 in a7a1357
# crash will happen in an async execution thread, need to grab the lock again to
# surface that exception
dynamo_res = dynamo_linear(xla_x)
try:
  print(dynamo_res)
except:
  print('catch')
# it is hard to catch the C++ runtime error in python; instead we can check whether
# dynamo_res is still a placeholder after printing, which means C++ crashed.
self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))
A C++ crash at the PyTorch level can be caught with self.assertRaises, but not one at the libtpu level... I'm not sure why. Yeah, not even with this hack...
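For reference, a minimal sketch of the catchable case (plain CPU PyTorch, no TPU involved; the test class and method names here are just illustrative): an error raised at the PyTorch C++ level surfaces as a Python RuntimeError, which is what self.assertRaises relies on, whereas a libtpu-level abort never raises a Python exception, so there is nothing for it to catch.

import unittest
import torch

class CatchablePtLevelError(unittest.TestCase):

  def test_pt_level_cpp_error_is_catchable(self):
    a = torch.randn(2, 3)
    b = torch.randn(4, 5)
    # The shape mismatch is detected in PyTorch C++ and surfaces as a
    # Python RuntimeError, so assertRaises can catch it.
    with self.assertRaises(RuntimeError):
      torch.matmul(a, b)

if __name__ == '__main__':
  unittest.main()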
cc @will-cromar Do you know how to catch a libtpu exception at the Python level? Appreciate your insights.
I don't think you can. To make a proper runtime error, you have to raise an exception, and Google internal binaries don't generally do that. I wrote about a similar case in #6700 (comment)
Thanks, Will. That makes a lot of sense now.
torch_xla/csrc/aten_xla_type.cpp
Outdated
  tensor_methods::custom_sharding_(output_tensor,
-                                  input_tensor->sharding_spec());
+                                  input_tensor->sharding_spec(),
+                                  CustomSharding::Type::kSharding);
So you assume only tensors with kSharding will be called with in-place ops?
That's the original design, which aligns with the original design of SPMD... So yeah, just for kSharding.
Can we make kSharding the default then? That way most people reading this code won't need to figure out what kSharding actually means.
enum class Type {
  kSharding,
  kSPMDFullToShardShape,
  kSPMDShardToFullShape,
};
This enum is really confusing; can you add some comments about what these actually do? I was reading the SPMD code again: this op itself only means we want to shard the underlying value, and the actual sharding resides in the XlaTensor or the base XLA IR object?
Right, this is just the name of the custom call. The sharding annotation is in XlaTensor as normal. I can add more explanations.
Maybe we can annotate explicitly that this is the sharding type for the custom call, in the enum class name or something.
I guess the current approach sort of does it already? Can you be more specific? @yeounoh
I agree, Type is already defined under CustomSharding.
LGTM, thanks!
Thanks, Yeounoh!
Summary:
This pull request supports SPMDFullToShardShape, a custom op that opens a region for a non-partitioned graph in an SPMD program. It stops SPMD auto-sharding and partitioning in that region and therefore allows manual sharding, such as cc ops.
To implement it, this pull request expands the CustomSharding node to accept a new type. Note that the output shape of the op needs to be the shard shape of the input, and the node needs to have a manual sharding annotation.
Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_spmd_full_to_shard_shape
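As a minimal sketch of the shard-shape requirement described in the summary (the helper below is illustrative only, not an API added by this PR, and it assumes evenly divisible dimensions): with a (2, 2) device mesh and a global tensor of shape (8, 16) sharded on both axes, the per-shard shape that SPMDFullToShardShape must produce is (4, 8).

# Illustrative helper, not part of the PR: computes the per-shard (local) shape
# from the global shape, the mesh shape, and a partition spec that maps each
# tensor dim to a mesh axis index (or None if the dim is replicated).
def shard_shape(global_shape, mesh_shape, partition_spec):
  return tuple(
      dim if axis is None else dim // mesh_shape[axis]
      for dim, axis in zip(global_shape, partition_spec))

# An (8, 16) tensor sharded over both axes of a (2, 2) mesh yields (4, 8) shards.
assert shard_shape((8, 16), (2, 2), (0, 1)) == (4, 8)
# Replicating dim 0 and sharding dim 1 yields (8, 8) shards.
assert shard_shape((8, 16), (2, 2), (None, 1)) == (8, 8)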