
[mxfp8] Skip weight gradient for frozen weights in mx_mm backward#4471

Open
ultism wants to merge 1 commit into pytorch:main from ultism:perf/mxfp8-skip-wgrad-frozen-weight

Conversation

ultism (Contributor) commented Jun 9, 2026

Summary

When the weight does not require grad (e.g. a frozen base in LoRA, or frozen-layer finetuning), autograd discards grad_weight, so computing it in mx_mm.backward (the wgrad GEMM plus its two dim1 MXFP8 casts) is pure overhead.

This guards the weight-gradient computation on ctx.needs_input_grad[1] and returns grad_weight=None for a frozen weight. It's the same idiom already used by the int8_mixed_precision and bitnet training linears.

# input_t @ grad_output = grad_weight
if not ctx.needs_input_grad[1]:
    grad_weight = None
elif wgrad_with_hp:
    ...
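For readers unfamiliar with the idiom, here is a minimal sketch of the guard in a custom torch.autograd.Function. It assumes a plain high-precision matmul in nn.Linear layout as a hypothetical stand-in for mx_mm (the class name FrozenAwareMatmul is illustrative, and the MXFP8 casts and scaled GEMMs are omitted). It shows a frozen weight ending up with no gradient while grad_input is identical to the trainable case.

```python
import torch


class FrozenAwareMatmul(torch.autograd.Function):
    """Hypothetical stand-in for mx_mm: plain high-precision matmul, no MXFP8.

    Only the ctx.needs_input_grad guard on the weight gradient is illustrated.
    """

    @staticmethod
    def forward(ctx, input_hp, weight_hp):
        ctx.save_for_backward(input_hp, weight_hp)
        # nn.Linear layout: input (M, K), weight (N, K) -> output (M, N)
        return input_hp @ weight_hp.t()

    @staticmethod
    def backward(ctx, grad_output):
        input_hp, weight_hp = ctx.saved_tensors
        # needs_input_grad holds one Python bool per forward() input
        grad_input = grad_output @ weight_hp if ctx.needs_input_grad[0] else None
        if not ctx.needs_input_grad[1]:
            # Frozen weight: autograd would discard grad_weight, so skip the wgrad GEMM
            grad_weight = None
        else:
            grad_weight = grad_output.t() @ input_hp  # (N, M) @ (M, K) -> (N, K)
        return grad_input, grad_weight


torch.manual_seed(0)
x = torch.randn(4, 8, requires_grad=True)
w_frozen = torch.randn(16, 8)  # requires_grad=False, e.g. a frozen LoRA base
w_train = w_frozen.clone().requires_grad_(True)

FrozenAwareMatmul.apply(x, w_frozen).sum().backward()
grad_input_frozen = x.grad.clone()

x.grad = None
FrozenAwareMatmul.apply(x, w_train).sum().backward()
grad_input_train = x.grad.clone()

print(w_frozen.grad, w_train.grad.shape)  # None torch.Size([16, 8])
```

Returning None for an input that does not require grad is always safe in an autograd.Function, which is why the guard changes no results, only the amount of work done.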

Eager vs compile

This is an eager-mode win. Under torch.compile, AOTAutograd already specializes the backward on requires_grad and prunes the unused grad_weight subgraph via dead-code elimination (DCE), so the compiled frozen path is already optimal; this change just brings eager to parity and does not change compiled behavior. Because ctx.needs_input_grad[1] is a plain Python bool (a trace-time constant), the branch does not introduce a graph break.

Measured on an RTX 5060 Ti (sm120): fwd+bwd time of a linear with a frozen vs. a trainable weight, reported as ratio = trainable / frozen (>1 means the frozen path skips the wgrad work). Absolute numbers vary run-to-run with GPU clocks; the within-run ratio is the signal:

| path | shape (M, K, N) | eager (this PR) | eager (main) | compile (main) |
|---|---|---|---|---|
| LLaMA MLP up | 8192, 4096, 14336 | 1.54x | 1.01x | 1.52x |
| LLaMA MLP down | 8192, 14336, 4096 | 1.57x | 1.00x | 1.24x |

So on main, eager does not skip the wgrad (≈1.0x) while compile already does (≈1.2–1.5x); this PR closes the eager gap.
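A rough sketch of how such a trainable/frozen ratio can be measured. This is not the harness behind the table above: it times fwd+bwd with time.perf_counter, uses a plain torch.nn.Linear as a stand-in for the MXFP8 linear, and the helper names median_fwd_bwd_seconds and trainable_over_frozen are hypothetical.

```python
import time

import torch


def median_fwd_bwd_seconds(layer, x, iters=20, warmup=3):
    """Median wall-clock time of one forward+backward pass through `layer`."""

    def step():
        x.grad = None
        layer.zero_grad(set_to_none=True)
        layer(x).sum().backward()
        if x.is_cuda:
            torch.cuda.synchronize()  # GPU work is async; wait before reading the clock

    for _ in range(warmup):
        step()
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        step()
        times.append(time.perf_counter() - start)
    times.sort()
    return times[len(times) // 2]


def trainable_over_frozen(m, k, n, device="cpu"):
    """ratio = trainable / frozen; > 1 means the frozen path skipped the wgrad work."""
    torch.manual_seed(0)
    x = torch.randn(m, k, device=device, requires_grad=True)
    # Stand-in module; the module under test would be substituted here.
    layer = torch.nn.Linear(k, n, bias=False, device=device)
    layer.weight.requires_grad_(False)
    t_frozen = median_fwd_bwd_seconds(layer, x)
    layer.weight.requires_grad_(True)
    t_trainable = median_fwd_bwd_seconds(layer, x)
    return t_trainable / t_frozen


ratio = trainable_over_frozen(256, 512, 1024)
print(f"trainable / frozen = {ratio:.2f}x")
```

Plain nn.Linear's built-in backward already skips the weight gradient for a frozen weight, so this stand-in will typically report a ratio above 1. Reproducing the eager (main) vs. eager (this PR) columns would mean swapping in the MXFP8 linear under test and running on a supported GPU with shapes like those in the table. Comparing the two settings within one run, rather than absolute times, follows the same reasoning given above about GPU clock variation.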

Test

test_mxfp8_linear_frozen_weight_skips_wgrad (EMULATED + AUTO): a frozen weight gets grad_weight=None, a trainable one does not, grad_input is bit-identical between the two, and the frozen path performs strictly fewer dim1 casts (verified by counting _to_mxfp8_dim1_kernel_wrapper calls). This fails on main (the frozen path still runs the wgrad casts) and passes here.
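The cast-counting check can be approximated with unittest.mock: patch the cast function with a wrapping mock and read its call_count. This is a sketch under assumptions: kernels.cast_dim1 is a hypothetical stand-in for _to_mxfp8_dim1_kernel_wrapper that performs no quantization, and CountedMatmul is an illustrative Function that routes its two wgrad casts through it. Dispatching through a namespace attribute keeps the patch target easy to reach without knowing the module path.

```python
import types
from unittest import mock

import torch

# Hypothetical kernel table; the real MXFP8 dim1 cast is replaced by a transpose.
kernels = types.SimpleNamespace(
    cast_dim1=lambda t: t.t().contiguous(),  # stand-in: no real quantization
)


class CountedMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x @ w.t()  # x (M, K), w (N, K) -> (M, N)

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        grad_x = grad_out @ w if ctx.needs_input_grad[0] else None
        grad_w = None
        if ctx.needs_input_grad[1]:
            # wgrad path: two dim1 casts, then the GEMM
            grad_out_t = kernels.cast_dim1(grad_out)  # (N, M)
            x_t = kernels.cast_dim1(x)  # (K, M)
            grad_w = grad_out_t @ x_t.t()  # (N, M) @ (M, K) -> (N, K)
        return grad_x, grad_w


def count_dim1_casts(weight_requires_grad):
    """Run one fwd+bwd and return how many times the dim1 cast was called."""
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=weight_requires_grad)
    with mock.patch.object(kernels, "cast_dim1", wraps=kernels.cast_dim1) as spy:
        CountedMatmul.apply(x, w).sum().backward()
    return spy.call_count


frozen_casts = count_dim1_casts(False)
trainable_casts = count_dim1_casts(True)
print(frozen_casts, trainable_casts)  # 0 2
```

Because wraps= forwards every call to the original function, the spy counts invocations without changing the numerics, so the same run can also compare gradients.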

test_mxfp8_linear_frozen_weight_compile (EMULATED + AUTO): compiles the frozen module with fullgraph=True and checks the result matches eager and weight.grad is None, locking in compile-safety.

Verified locally on sm120: the full test_mxfp8_linear.py suite passes; ruff check and ruff format --check clean.

Context

Frozen-base / parallel-adapter (LoRA) training on a quantized base is the motivating use case; see #4376 for the analogous float8 discussion and #4022 for the MXFP8 training roadmap. Touches the same mx_mm.backward as #4470 (a non-contiguous grad_output fix) but is independent of it.

🤖 Generated with Claude Code

When the weight does not require grad (e.g. a frozen base in LoRA or
frozen-layer finetuning), grad_weight is discarded by autograd, so the
wgrad GEMM and its two dim1 MXFP8 casts are pure overhead. Guard the
weight-gradient computation on ctx.needs_input_grad[1] and return
grad_weight=None when the weight is frozen. Same idiom already used by
the int8_mixed_precision and bitnet training linears.

This is an eager-mode win (~1.5x on a frozen LLaMA-MLP-sized linear
fwd+bwd). Under torch.compile, AOTAutograd already prunes the unused
grad_weight via dead-code elimination, so this brings the eager path to
parity without changing compiled behavior. The branch tests a Python
bool (ctx.needs_input_grad), i.e. a trace-time constant, so it does not
introduce a graph break.

Tests: a frozen weight gets grad_weight=None with grad_input unchanged
and the wgrad dim1 casts skipped; a torch.compile(fullgraph=True) case
checks the frozen backward stays graph-break-free and matches eager.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
pytorch-bot commented Jun 9, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4471

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 2 Unrelated Failures

As of commit eb6884b with merge base 5165bfb:

NEW FAILURES - The following jobs have failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

meta-cla bot added the CLA Signed label Jun 9, 2026
vkuzo (Contributor) commented Jun 10, 2026

@claude review

vkuzo added the module: training label Jun 10, 2026
vkuzo (Contributor) commented Jun 11, 2026

LGTM, and the CI failures look unrelated. @ultism, can you rebase? Your other PR, which we just merged, conflicts with this one. Thank you!

@andrewor14 can you help land this after I go OOO please? Thank you!


Labels

CLA Signed, module: training


2 participants