cudagraphed compiled module slows down when called recursively #138949

Open
vmoens opened this issue Oct 25, 2024 · 4 comments
Labels
module: cuda graphs · module: inductor · oncall: pt2 · triaged

Comments

vmoens (Contributor) commented Oct 25, 2024

🐛 Describe the bug

In RL, we'd like to compile a policy that computes a value which may later be used as an input to the next call of that same policy.

It seems that torch.compile with cudagraphs has some unexpected behavior if we run a module and detach its output (as opposed to detaching all of the module's weights).

The following code shows a minimal implementation of such a loop:

import torch
import tqdm
from torchrl.modules import MLP

from tensordict.nn import CudaGraphModule

mode = "reduce-overhead"

REDUCE_OVERHEAD = True
CG_MODULE = False
CLONE = False
DETACH = False

policy = MLP(512, out_features=512, num_cells=[512, 512, 512], device="cuda")
if DETACH:
    policy.requires_grad_(False)

obs = torch.randn(1, 512, device="cuda")

if REDUCE_OVERHEAD:
    policy = torch.compile(policy, mode=mode, fullgraph=True)
elif CG_MODULE:
    policy = CudaGraphModule(torch.compile(policy))
else:
    policy = torch.compile(policy)

for _ in tqdm.tqdm(range(1000_000)):
    obs = policy(obs).detach()
    if CLONE:
        obs = obs.clone()

Now some experiments on H100:

  • REDUCE_OVERHEAD = False, CG_MODULE = False, CLONE = False: about 6500 iter/sec, steady
  • REDUCE_OVERHEAD = False, CG_MODULE = False, CLONE = True: about 5800 iter/sec, steady
  • REDUCE_OVERHEAD = False, CG_MODULE = True, CLONE = False: about 25000 iter/sec, steady
  • REDUCE_OVERHEAD = False, CG_MODULE = True, CLONE = True: about 22600 iter/sec, steady

Now with reduce overhead:

  • REDUCE_OVERHEAD = True, CLONE = False: about 320 iter/sec initially, quickly decreasing to 50 iter/sec and lower
  • REDUCE_OVERHEAD = True, CLONE = True: about 4.3 iter/sec, steady

This is not observed if DETACH=True, so it seems this is related to the way graphs are built within the compiled code.

Versions

nightly

cc @mcarilli @ezyang @eellison @penguinwu @chauhang @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov

ezyang (Contributor) commented Oct 29, 2024

@eellison, probably cuda graph tree related?

vmoens (Contributor, Author) commented Oct 29, 2024

We figured out that calling torch.compiler.cudagraph_mark_step_begin() solves it.
The reason is simply that the graph is not released even though the value is detached; in eager mode that obviously wouldn't be an issue.
Since in other cases you're explicitly asked to call torch.compiler.cudagraph_mark_step_begin() when required, I assume this was always the expected usage.
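
For reference, a minimal sketch of the workaround (same MLP policy as in the repro above; the extra .clone() follows the general guidance to clone cudagraph outputs that must survive into the next step, and may not be strictly required here):

policy = torch.compile(policy, mode="reduce-overhead", fullgraph=True)

for _ in tqdm.tqdm(range(1000_000)):
    # Tell the cudagraph machinery the previous iteration is finished,
    # so its recordings/outputs can be reused instead of piling up.
    torch.compiler.cudagraph_mark_step_begin()
    obs = policy(obs).detach().clone()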
By the way, we've found another interesting issue: cudagraphs with models that are not detached, combined with MuJoCo rendering on-device, cause a similar problem (very slow calls to the model), again without any warning. This is again due to the graphs piling up in memory and conflicting with MuJoCo until an OOM occurs. (We tried to come up with a simple minimal repro but didn't manage to.)

I'm mentioning this because in two unrelated cases we faced perf issues with "reduce-overhead" that had the same solution (torch.compiler.cudagraph_mark_step_begin()) but no warning, and quite different underlying causes.

I'm curious: what would an automated call to torch.compiler.cudagraph_mark_step_begin() on entry to the compiled region break, unless a certain opt-out flag is passed when calling compile? Am I correct that the usages that would (silently) break would look like this:

c1 = torch.compile(c1, mode="reduce-overhead")
c2 = torch.compile(c2, mode="reduce-overhead")

y = c1(x) # calls torch.compiler.cudagraph_mark_step_begin() automatically
z = c2(x, y) # calls torch.compiler.cudagraph_mark_step_begin() automatically
z.sum().backward() # does not propagate through c1 weights

The reason I'm asking is that if cudagraphs conflict with other libs (like MuJoCo) it might be hard to raise a warning consistently. Also, in a codebase like TDMPC2, we have several calls to cudagraph_mark_step_begin scattered across the codebase, which is a bit cumbersome (especially given that most users will not really understand what that call is doing precisely).

One option could be

>>> c1 = torch.compile(c1, mode="reduce-overhead")
>>> y0 = c1(x)
>>> y1 = c1(x)
Warning: two consecutive calls of cudagraphed models have been made while the `free_tensor` argument has not been passed to the compiler. If you want to keep the graph in between iterations, please call compile with `free_tensor=False`. To silence this warning and keep the current behaviour pass `free_tensor=True`.

desertfire added the module: cuda graphs, module: inductor, and triage review labels on Oct 29, 2024
eellison (Contributor) commented:

To state the current behavior:

There are two stages in cudagraphs: recording and replaying. If you end a recording too early it will manifest as an error; if you end a recording too late it will manifest as slowness.

When we end a recording is controlled here:

def can_start_new_generation(self) -> bool:
    if not self.in_new_torch_compile_invocation():
        return False
    if self.user_invoked_mark_step():
        return True
    return not self.running_forwards_with_pending_backwards

I think what you are hitting is the final case - running_forwards_with_pending_backwards. Is the code you have above realistic? When would you be running a forward for a million iterations, accumulating a gradient, but never calling backward? I'm surprised this is OOMing already. (Or maybe I am misreading.)

There is also this code that detects when we should issue a warning:

def check_warn_on_unable_to_start_executing(self, function_id: FunctionID) -> None:

(see example here).

I think the bug here is that for some reason the warning did not fire.

I'm curious: what would an automated call to torch.compiler.cudagraph_mark_step_begin()

We have this to some extent already, in that we detect when torch.compile is invoked and use that to help determine when to end a recording. But we are not willing to break the case of multiple, separate torch.compile wrappings of submodules in training. This is extremely common.
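
For concreteness, a rough sketch of that pattern (hypothetical encoder/head submodules, not from the issue): two separately compiled regions whose outputs must stay alive until a shared backward, so automatically marking a new step on every compiled call would be incorrect:

import torch

encoder = torch.compile(torch.nn.Linear(512, 256).cuda(), mode="reduce-overhead")
head = torch.compile(torch.nn.Linear(256, 1).cuda(), mode="reduce-overhead")

x = torch.randn(8, 512, device="cuda")
h = encoder(x)        # first compiled region
loss = head(h).sum()  # second compiled region; h must not be freed in between
loss.backward()       # gradients flow through both compiled submodules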

>>> c1 = torch.compile(c1, mode="reduce-overhead")
>>> y0 = c1(x)
>>> y1 = c1(x)
Warning: two consecutive calls of cudagraphed models have been made while the `free_tensor` argument has not been passed to the compiler. If you want to keep the graph in between iterations, please call compile with `free_tensor=False`. To silence this warning and keep the current behaviour pass `free_tensor=True`.

Today, the way this is handled in inference is: if you access y0, we'll give an informative error message. We did warning spam at one point but it tends to get ignored.

In training, we would expect y0 to participate in backward somehow.
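
Roughly what the inference behavior looks like (a sketch, not from the issue; cudagraphs only kick in after a few warmup calls, and the error message is paraphrased):

c1 = torch.compile(torch.nn.Linear(512, 512).cuda(), mode="reduce-overhead")
x = torch.randn(1, 512, device="cuda")

with torch.no_grad():
    for _ in range(3):   # warmup so the region is actually cudagraphed
        y0 = c1(x)
    y1 = c1(x)           # new step: y0's cudagraph-owned storage is reused
    try:
        y0.sum()         # expected to raise, pointing at the overwritten output
    except Exception as e:
        print(e)         # the message suggests cloning y0 if it must outlive the next call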

desertfire added the triaged label and removed the triage review label on Oct 29, 2024
ezyang (Contributor) commented Oct 30, 2024

We talked about potentially marking the step in set_stance.
