Draft-mode export: ep.run_decompositions() doesn't run with real tensor prop #139283

Open
zou3519 opened this issue Oct 30, 2024 · 0 comments
Labels: export-triaged, oncall: export, oncall: pt2

zou3519 commented Oct 30, 2024

Repro: patch in #139213 (needed for an error to show up), then run the following script:

import torch
import torch._functorch.config

@torch.library.custom_op("export::foo", mutates_args={})
def foo(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x * y

class Foo(torch.nn.Module):
    def forward(self, x, y):
        return foo(x, y)

model = Foo()
inputs = (torch.randn(1, 3), torch.randn(2, 1))
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
    ep = torch.export.export_for_training(model, inputs)
    nodes = list(ep.module().graph.nodes)
    ep.run_decompositions({})

What's going on:

  • the linked PR (#139213) lets real_tensor_prop infer a meta kernel for an operator during export
  • the exported program contains only the FakeTensors, not the real tensors that produced them (should it?)
  • run_decompositions re-traces the graph without real tensor propagation, because there are no real tensors to propagate
  • the re-trace errors because the operator has no meta kernel (the one inferred during export isn't persistent)
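
The sequence above can be mimicked with a toy model in plain Python. This is only an illustration of the failure mode, not PyTorch code: `ToyRegistry`, `real_tensor_prop`, `trace`, and `broadcast` are hypothetical names, and shapes stand in for tensors.

```python
from contextlib import contextmanager


class ToyRegistry:
    """Hypothetical stand-in for an operator's meta-kernel table (not a PyTorch API)."""

    def __init__(self):
        self.meta_kernels = {}

    @contextmanager
    def real_tensor_prop(self, op, real_fn):
        # While "export" runs, a meta kernel is inferred from the real computation.
        self.meta_kernels[op] = real_fn
        try:
            yield
        finally:
            # The inferred kernel is dropped afterwards: it is not persistent.
            del self.meta_kernels[op]

    def trace(self, op, *shapes):
        # Tracing needs a meta kernel to compute output shapes.
        if op not in self.meta_kernels:
            raise RuntimeError(f"no meta kernel for {op}")
        return self.meta_kernels[op](*shapes)


def broadcast(a, b):
    # Same-rank broadcasting only, enough for the (1, 3) and (2, 1) repro shapes.
    return tuple(max(i, j) for i, j in zip(a, b))


registry = ToyRegistry()

# "Export": real tensors are available, so a meta kernel can be inferred.
with registry.real_tensor_prop("export::foo", broadcast):
    exported_shape = registry.trace("export::foo", (1, 3), (2, 1))

# "run_decompositions": the re-trace has no real tensors and no inferred kernel.
try:
    registry.trace("export::foo", (1, 3), (2, 1))
    retrace_error = None
except RuntimeError as e:
    retrace_error = str(e)
```

In this model the first trace succeeds and the second raises, which mirrors the last two bullets.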

One possible fix: if run_decompositions is not decomposing the operator, it uses the existing FakeTensors to infer a meta kernel. The more general question is whether other downstream graph passes will hit the same problem because they don't have access to the real tensors.

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

@pianpwk assigned zou3519 and unassigned zou3519 Oct 30, 2024
@yushangdi added the export-triaged label Oct 30, 2024