Factor out dispatch and layout registration table #360
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/360
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 8929358 with merge base 8841094. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
    import functools

    # torch_function and torch_dispatch operator dispatch registrations
    _ATEN_OP_OR_TORCH_FN_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict)
Nit: is this typing right?
Yeah, but I can also change Any to Callable to be more specific, I think.
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    _ATEN_OP_OR_TORCH_FN_TABLE[cls][op] = wrapper
Part of me still doesn't like having the function table shared among subclasses... I don't have a great reason against it though, lol, so that's not blocking.
Actually curious to hear more. I want us to get the aesthetics right
I don't actually have a logically good reason, lol. It's more of a gut check: I don't love global things that tie a bunch of potentially disparate ideas/implementations together.
The dispatcher in PyTorch is a giant global vtable that ties everything together, so there are obvious counterpoints. That's why this is a super nit and not blocking.
OK yeah, I think we can follow what pytorch is already doing
This breaks QLoRA in TorchTune. I'm planning to add an FSDP unit test in TorchAO. We probably need to revert the NF4 changes. #380
[rank0]: File "/data/users/weif/pytorch-official/pytorch/torch/distributed/_tensor/dispatch.py", line 205, in dispatch
[rank0]: local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank0]: File "/data/users/weif/pytorch-official/pytorch/torch/_ops.py", line 666, in __call__
[rank0]: return self_._op(*args, **kwargs)
[rank0]: File "/data/users/weif/ao/torchao/dtypes/nf4tensor.py", line 541, in __torch_function__
[rank0]: return _ATEN_OP_OR_TORCH_FN_TABLE[cls][func](*args, **kwargs)
[rank0]: File "/data/users/weif/ao/torchao/dtypes/utils.py", line 25, in wrapper
[rank0]: return func(*args, **kwargs)
[rank0]: TypeError: nf4_detach() missing 1 required positional argument: 'args'
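One plausible reading of this traceback is a calling-convention mismatch: the wrapper unpacks its arguments, while the handler expects them packed. The sketch below reproduces the same error message in plain Python. The signature `nf4_detach(aten_op, args, kwargs=None)` is an assumption (the handler's definition is not shown on this page), and `wrapper_for` is a hypothetical helper that mirrors the wrapper from the diff.

```python
def wrapper_for(func):
    # Mirrors the wrapper in the diff: arguments are unpacked into the handler.
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    return wrapper


# Assumed handler signature: the op, then packed positional args, then kwargs.
def nf4_detach(aten_op, args, kwargs=None):
    return args[0]


handler = wrapper_for(nf4_detach)
tensor = object()  # stand-in for the NF4 tensor

try:
    # Same call shape as the __torch_function__ frame: handler(*args, **kwargs)
    # with a single positional argument, which lands in `aten_op`.
    handler(tensor)
    message = ""
except TypeError as exc:
    message = str(exc)

# Calling with the convention the handler expects works.
result = handler("detach", (tensor,), {})
```

Under this assumption, the fix would be to make the call site and the handler agree on one convention, either by passing `(func, args, kwargs)` through or by changing the handlers to accept unpacked arguments.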
Summary:
As the title says. After the refactor, we can use the common utils for new dtypes as well.
Test Plan:
python test/dtypes/test_nf4.py
python test/dtypes/test_aqt.py
python test/integration/test_integration.py
Reviewers:
Subscribers:
Tasks:
Tags: