
Conversation

alanwaketan
Collaborator

Summary:
This pull request adds support for SPMDFullToShardShape, a custom op that opens a region for a non-partitioned graph within an SPMD program. SPMD auto-sharding and partitioning stop in that region, which allows manual sharding of ops such as cc (collective communication) ops.

To implement it, this pull request expands the CustomSharding node to accept a new type. Note that the output shape of the op needs to be the shard shape of the input, and the node needs to carry a manual sharding annotation.

Test Plan:
PJRT_DEVICE=TPU python test/spmd/test_xla_sharding.py -v -k test_spmd_full_to_shard_shape
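
For intuition, the shape contract above is just the per-shard shape of the input. A minimal illustrative helper (not part of the patch), assuming each sharded dimension divides evenly across its shards:

#include <cstdint>
#include <vector>

// Illustrative sketch only: compute the shard shape that the output of
// SPMDFullToShardShape must have, given the full shape and the number of
// shards along each dimension. Assumes an even split per dimension.
std::vector<int64_t> ShardShape(const std::vector<int64_t>& full_shape,
                                const std::vector<int64_t>& shards_per_dim) {
  std::vector<int64_t> shard_shape;
  for (size_t i = 0; i < full_shape.size(); ++i) {
    shard_shape.push_back(full_shape[i] / shards_per_dim[i]);
  }
  return shard_shape;
}

// e.g. ShardShape({8, 128}, {4, 1}) == {2, 128}: an [8, 128] tensor sharded
// 4-ways on dim 0 enters the manual region with shape [2, 128].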

@alanwaketan alanwaketan requested review from yeounoh and jonb377 April 12, 2024 21:34
return xtensors;
}

bool IsIr(const at::Tensor& tensor) {
Collaborator

nit, IsNonDeviceDataIR

# It looks like XLA doesn't like having only manual sharding in the HLO.
# It needs to be paired with SPMDFullToShardShape/SPMDShardToFullShape.
# For some reason, the following exception cannot be caught.
# xx.cpu()
Collaborator

Do you intend to keep this xx.cpu()?

Collaborator Author

Yeah, it's more of a note that this won't work... I was trying to use self.assertRaises, but that doesn't capture the exception. I have noticed this before too: when libtpu crashes, it's hard to catch it at the Python level. Not sure why. Maybe you have some better ideas?

Collaborator

Oh, I think I ran into a similar issue before. The way I handled it was ugly:

# The crash will happen in an async execution thread; we need to grab the lock
# again to surface that exception.
dynamo_res = dynamo_linear(xla_x)
try:
  print(dynamo_res)
except:
  print('catch')
# It is hard to catch the C++ runtime error in Python. Instead, we can check
# whether dynamo_res is still a placeholder after printing; if it is, the C++
# side crashed.
self.assertTrue(torch_xla._XLAC._is_placecholder(dynamo_res))

Collaborator Author

A C++ crash at the PyTorch level can be caught with self.assertRaises, but not one at the libtpu level... I'm not sure why... Yeah, not even with this hack...

Collaborator Author

cc @will-cromar Do you know how to catch a libtpu exception in Python? Appreciate your insights.

Collaborator

I don't think you can. To make a proper runtime error, you have to raise an exception, and Google internal binaries don't generally do that. I wrote about a similar case in #6700 (comment)

Collaborator Author

Thanks, Will. That makes a lot of sense now.

  tensor_methods::custom_sharding_(output_tensor,
-                                  input_tensor->sharding_spec());
+                                  input_tensor->sharding_spec(),
+                                  CustomSharding::Type::kSharding);
Collaborator

So you assume only tensors with kSharding will be called with in-place ops?

Collaborator Author

That's the original design, which aligns with how SPMD was originally designed... So yeah, only for kSharding...

Collaborator

Can we make kSharding the default then? That way, most people reading this code won't need to figure out what kSharding actually means.
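
Something roughly like the sketch below; the parameter names and types here are guesses from the call site quoted above, not copied from the patch:

// Sketch only (hypothetical signature): defaulting the type keeps existing
// call sites working while making the common case explicit.
void custom_sharding_(
    const XLATensorPtr& output_tensor,
    const std::shared_ptr<XLATensor::ShardingSpec>& sharding_spec,
    CustomSharding::Type type = CustomSharding::Type::kSharding);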

Comment on lines +10 to +14
enum class Type {
kSharding,
kSPMDFullToShardShape,
kSPMDShardToFullShape,
};
Collaborator

This enum is really confusing; can you add some comments about what each value actually does? I was reading the SPMD code again: this op itself only means we want to shard the underlying value, and the actual sharding resides in the XLATensor or the base XLA IR object?

Collaborator Author

Right, this is just the name of the custom call. The sharding annotation lives in the XLATensor as usual. I can add more explanations.
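
Something along these lines, for example (just a sketch; the exact wording isn't final):

// Which custom call to emit for this node. The actual sharding annotation
// still lives on the XLATensor / IR node, not in this enum.
enum class Type {
  // Plain sharding annotation on the wrapped value.
  kSharding,
  // Enters a manually sharded region: the output shape is the shard shape
  // of the input.
  kSPMDFullToShardShape,
  // Leaves a manually sharded region: the output shape is the full,
  // unsharded shape.
  kSPMDShardToFullShape,
};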

Contributor

Maybe we can state explicitly that this is the sharding type for the custom call, in the enum class name or something.

Collaborator Author

I guess the current approach sort of does it already? Can you be more specific? @yeounoh

Contributor

I agree, Type is already defined under CustomSharding

Contributor

@yeounoh yeounoh left a comment

LGTM, thanks!

@alanwaketan
Collaborator Author

Thanks, Yeounoh!

@alanwaketan alanwaketan merged commit 2763248 into master Apr 15, 2024
lausannel pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024
baoleai pushed a commit to AlibabaPAI/xla that referenced this pull request Aug 6, 2024