cudagraphed compiled module slows down when called recursively #138949
Comments
@eellison, probably cuda graph tree related?
We figured out a workaround. I'm mentioning this because in two unrelated cases we faced perf issues with "reduce-overhead" that had the same solution. I'm curious: what would an automated version of that call look like?
The reason I'm asking is that if cudagraph conflicts with other libs (like mujoco) it might be hard to raise a warning consistently. Also, in code like TDMPC2 we have several calls to compiled modules, so one option could be to trigger the fix automatically.
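A minimal sketch of the kind of manual workaround being discussed, assuming the call in question is torch.compiler.cudagraph_mark_step_begin() (that this is the exact call referenced above is an assumption):

```python
import torch

# A toy "policy": any CUDA module compiled with reduce-overhead (cudagraphs).
net = torch.nn.Linear(128, 128).cuda()
policy = torch.compile(net, mode="reduce-overhead")

x = torch.randn(32, 128, device="cuda")
for _ in range(10):
    # Mark the start of a new iteration so the cudagraph recording for the
    # previous iteration ends at a well-defined point instead of being extended.
    torch.compiler.cudagraph_mark_step_begin()
    x = policy(x).detach()
torch.cuda.synchronize()
```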
To state the current behavior: there are two stages in cudagraphs, recording and replaying. If you end a recording early it will manifest as an error; if you end a recording late it will manifest as being slow. When we end a recording is controlled here: pytorch/torch/_inductor/cudagraph_trees.py, lines 2263 to 2270 (at 9af1816).
I think what you are hitting is the final case. There is also code that detects when we should issue a warning: pytorch/torch/_inductor/cudagraph_trees.py, line 2327 (at 9af1816).
I think the bug here is that for some reason the warning did not fire.
We have this to some extent already, in that we detect when torch.compile is invoked and use that to help determine when to end a recording. But we are not willing to break the case of multiple, separate torch.compile wrappings of submodules in training, which is extremely common.
Today, the way this is handled in inference is that if you access y0 we'll give an informative error message. We did warning spam at one point, but it tends to get ignored. In training, we would expect y0 to participate in backward somehow.
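To illustrate the inference behavior described above, a sketch (that this minimal case triggers the error, and its exact wording, are assumptions):

```python
import torch

net = torch.nn.Linear(64, 64).cuda()
compiled = torch.compile(net, mode="reduce-overhead")

with torch.no_grad():
    y0 = compiled(torch.randn(8, 64, device="cuda"))
    y1 = compiled(torch.randn(8, 64, device="cuda"))
    # y0's storage may have been reused by the second replay; accessing it is
    # expected to raise an informative error rather than silently return stale data.
    print(y0.sum())
```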
We talked about potentially marking the step in set_stance.
🐛 Describe the bug
In RL, we'd like to compile a policy that computes a value that may be used later on as an input to the next call of that policy.
It seems that compile with cudagraphs has unexpected behavior if we run a module and detach its output (as opposed to detaching all of the module's weights).
The following code shows a silly implementation of such a loop:
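The original snippet is not reproduced here; below is a hypothetical reconstruction along the lines described, where the module, sizes, and the DETACH flag are placeholders:

```python
import torch

DETACH = False  # the slowdown is reported only when the output is NOT detached

class Policy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(256, 256)

    def forward(self, obs, prev):
        # the policy consumes its own previous output as an input
        return torch.tanh(self.lin(obs) + prev)

policy = Policy().cuda()
compiled = torch.compile(policy, mode="reduce-overhead")

obs = torch.randn(64, 256, device="cuda")
prev = torch.zeros(64, 256, device="cuda")
for _ in range(100):
    out = compiled(obs, prev)
    prev = out.detach() if DETACH else out  # feed the output back into the next call
torch.cuda.synchronize()
```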
Now some experiments on H100:
Now with reduce overhead:
This is not observed if DETACH=True, so it seems this is related to the way graphs are built within the compiled code.
Versions
nightly
cc @mcarilli @ezyang @eellison @penguinwu @chauhang @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov