Factor out dispatch and layout registration table#360

Merged
jerryzh168 merged 1 commit into
pytorch:mainfrom
jerryzh168:refactor-utils
Jun 14, 2024

@jerryzh168 jerryzh168 commented Jun 14, 2024

Contributor

Summary:
As the title says. After the refactor, the common utils can be used for new dtypes as well.

# added for adding ops to op dispatch table, works for both aten and torch function
def _implements(cls, aten_ops_or_torch_fns):
    ...

# added for registering new layout class
def _register_layout_cls(cls, extended_layout: str):
    ...

# get layout tensor constructor
def _get_layout_tensor_constructor(cls, extended_layout: str) -> Callable:
    ...

Test Plan:
python test/dtypes/test_nf4.py
python test/dtypes/test_aqt.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:

@pytorch-bot

pytorch-bot Bot commented Jun 14, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/360

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 8929358 with merge base 8841094:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 14, 2024
@jerryzh168 jerryzh168 requested review from drisspg and msaroufim June 14, 2024 05:22
Comment thread torchao/dtypes/utils.py Outdated
import functools

# torch_function and torch_dispatch operator dispatch registrations
_ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict)

Contributor

Nit: is this typing right?

Contributor Author

Yeah, but I can also change Any to Callable to be more specific, I think.

Comment thread torchao/dtypes/utils.py
def wrapper(*args, **kwargs):
    return func(*args, **kwargs)

_ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper

Contributor

Part of me still doesn't like having the function table be shared among subclasses... I don't have a great reason against it though lol, so that's not blocking.

Contributor

Actually curious to hear more. I want us to get the aesthetics right

Contributor

I actually don't have any logically good reason lol, it's more of a gut check: I don't love global things that tie a bunch of potentially disparate ideas/implementations together.

The dispatcher in PyTorch is a giant global vtable that ties everything together, so there are obvious counterpoints. That's why this is a super nit and not blocking.

Contributor Author

OK yeah, I think we can follow what pytorch is already doing

@weifengpy weifengpy Jun 18, 2024

Contributor

This breaks QLoRA in TorchTune. I am planning to add an FSDP unit test in TorchAO. We probably need to revert the NF4 changes. #380

[rank0]:   File "/data/users/weif/pytorch-official/pytorch/torch/distributed/_tensor/dispatch.py", line 205, in dispatch
[rank0]:     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank0]:   File "/data/users/weif/pytorch-official/pytorch/torch/_ops.py", line 666, in __call__
[rank0]:     return self_._op(*args, **kwargs)
[rank0]:   File "/data/users/weif/ao/torchao/dtypes/nf4tensor.py", line 541, in __torch_function__
[rank0]:     return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)
[rank0]:   File "/data/users/weif/ao/torchao/dtypes/utils.py", line 25, in wrapper
[rank0]:     return func(*args, **kwargs)
[rank0]: TypeError: nf4_detach() missing 1 required positional argument: 'args'
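
The traceback is consistent with a calling-convention mismatch: the table wrapper forwards the call as `func(*args, **kwargs)`, while the error suggests `nf4_detach` expects the op plus a packed `args` sequence. Below is a minimal reproduction under that assumption. The handler signature is inferred from the error message, not copied from `nf4tensor.py`, and `FakeTensor` and the string op name are hypothetical stand-ins.

```python
def nf4_detach(aten_op, args, kwargs=None):
    # Assumed signature: (op, packed positional args, packed kwargs).
    return args[0]


def wrapper(*args, **kwargs):
    # Mirrors the wrapper in the traceback: arguments are forwarded unpacked.
    return nf4_detach(*args, **kwargs)


class FakeTensor:  # hypothetical stand-in for an NF4 tensor
    pass


t = FakeTensor()

try:
    # The table entry is invoked with the unpacked call arguments only,
    # so the tensor lands in the `aten_op` slot and `args` is missing.
    wrapper(t)
    error = None
except TypeError as exc:
    error = str(exc)

# Passing the op and the packed arguments matches the assumed signature.
result = nf4_detach("detach", (t,), {})
```

If that reading is right, either the handlers or the call site in `__torch_function__` would need to agree on one convention.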

@jerryzh168 jerryzh168 merged commit ca19e23 into pytorch:main Jun 14, 2024
@jerryzh168 jerryzh168 deleted the refactor-utils branch June 14, 2024 17:21
NicoleMayer pushed a commit that referenced this pull request Jul 9, 2024
yanbing-j pushed a commit to yanbing-j/ao that referenced this pull request Dec 9, 2024
5 participants