[mxfp8] Skip weight gradient for frozen weights in mx_mm backward #4471
Open
ultism wants to merge 1 commit into
When the weight does not require grad (e.g. a frozen base in LoRA or frozen-layer finetuning), grad_weight is discarded by autograd, so the wgrad GEMM and its two dim1 MXFP8 casts are pure overhead. Guard the weight-gradient computation on ctx.needs_input_grad[1] and return grad_weight=None when the weight is frozen. Same idiom already used by the int8_mixed_precision and bitnet training linears.

This is an eager-mode win (~1.5x on a frozen LLaMA-MLP-sized linear fwd+bwd). Under torch.compile, AOTAutograd already prunes the unused grad_weight via dead-code elimination, so this brings the eager path to parity without changing compiled behavior. The branch tests a Python bool (ctx.needs_input_grad[1]), i.e. a trace-time constant, so it does not introduce a graph break.

Tests: a frozen weight gets grad_weight=None with grad_input unchanged and the wgrad dim1 casts skipped; a torch.compile(fullgraph=True) case checks the frozen backward stays graph-break-free and matches eager.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
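The guard described above can be sketched on a toy autograd function. This is a minimal illustration, not the PR's code: `GuardedMM` and `run_backward` are hypothetical names, and a plain float matmul stands in for the MXFP8 casts and GEMMs in mx_mm.

```python
import torch


class GuardedMM(torch.autograd.Function):
    """Hypothetical stand-in for mx_mm: plain matmul, no MXFP8 casts."""

    @staticmethod
    def forward(ctx, x, weight):
        # x: (N, K), weight: (M, K) -> output: (N, M)
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        # dgrad: always computed in this sketch.
        grad_input = grad_output @ weight  # (N, K)
        grad_weight = None
        # needs_input_grad[1] is a plain Python bool for the second
        # forward argument (the weight). Skip the wgrad GEMM when False.
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t() @ x  # (M, K)
        return grad_input, grad_weight


def run_backward(weight_requires_grad):
    """Run fwd+bwd once and return (x, x.grad, weight.grad)."""
    torch.manual_seed(0)
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=weight_requires_grad)
    GuardedMM.apply(x, w).sum().backward()
    return x, x.grad, w.grad
```

With a frozen weight, `run_backward(False)` leaves the weight's `.grad` as `None`, while the input gradient is the same as in the trainable case.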
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4471
Note: Links to docs will display an error until the docs builds have been completed.
❌ 4 New Failures, 2 Unrelated Failures
As of commit eb6884b with merge base 5165bfb:
NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Contributor
@claude review
Contributor
lgtm and CI failures look unrelated. @ultism can you rebase, your other PR which we just merged has a conflict with this one. Thank you! @andrewor14 can you help land this after I go OOO please? Thank you!
vkuzo
approved these changes
Jun 11, 2026
Summary
When the weight does not require grad (a frozen base in LoRA or frozen-layer finetuning), autograd discards grad_weight, so computing it in mx_mm.backward (the wgrad GEMM plus its two dim1 MXFP8 casts) is pure overhead.

This guards the weight-gradient computation on ctx.needs_input_grad[1] and returns grad_weight=None for a frozen weight. It's the same idiom already used by the int8_mixed_precision and bitnet training linears.

Eager vs compile
This is an eager-mode win. Under torch.compile, AOTAutograd already specializes the backward on requires_grad and prunes the unused grad_weight subgraph via DCE, so the compiled frozen path is already optimal; this change just brings eager to parity. It does not change compiled behavior, and ctx.needs_input_grad[1] is a Python bool (a trace-time constant), so the branch does not introduce a graph break.

Measured on an RTX 5060 Ti (sm120), frozen vs trainable
fwd+bwd for a frozen linear, ratio = trainable / frozen (>1 means the wgrad work is skipped). Absolute numbers vary run-to-run with GPU clocks; the within-run ratio is the signal:

So on main, eager does not skip the wgrad (≈1.0x) while compile already does (≈1.2–1.5x); this PR closes the eager gap.

Test
test_mxfp8_linear_frozen_weight_skips_wgrad (EMULATED + AUTO): a frozen weight gets grad_weight=None, a trainable one does not, grad_input is bit-identical between the two, and the frozen path performs strictly fewer dim1 casts (verified by counting _to_mxfp8_dim1_kernel_wrapper calls). This fails on main (the frozen path still runs the wgrad casts) and passes here.
test_mxfp8_linear_frozen_weight_compile (EMULATED + AUTO): compiles the frozen module with fullgraph=True and checks the result matches eager and weight.grad is None, locking in compile-safety.

Verified locally on sm120: the full test_mxfp8_linear.py suite passes; ruff check and ruff format --check clean.

Context
Frozen-base / parallel-adapter (LoRA) training on a quantized base is the motivating use case; see #4376 for the analogous float8 discussion and #4022 for the MXFP8 training roadmap. Touches the same mx_mm.backward as #4470 (a non-contiguous grad_output fix) but is independent of it.

🤖 Generated with Claude Code
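The call-counting check mentioned under Test can be illustrated roughly as follows. This is a sketch under stated assumptions, not the PR's test: `Casts.to_dim1` is a hypothetical placeholder for the dim1 cast (here only a transposed copy), `CountedMM` is a toy function using plain float matmuls, and the spy is built with `unittest.mock.patch.object(..., wraps=...)`.

```python
import torch
from unittest import mock


class Casts:
    """Hypothetical namespace standing in for the module that owns the dim1 cast."""

    @staticmethod
    def to_dim1(t):
        # Placeholder for an MXFP8 dim1 cast: just a transposed, contiguous copy.
        return t.t().contiguous()


class CountedMM(torch.autograd.Function):
    """Toy linear whose wgrad path calls the placeholder cast twice."""

    @staticmethod
    def forward(ctx, x, weight):
        ctx.save_for_backward(x, weight)
        return x @ weight.t()

    @staticmethod
    def backward(ctx, grad_output):
        x, weight = ctx.saved_tensors
        grad_input = grad_output @ weight
        grad_weight = None
        if ctx.needs_input_grad[1]:
            go_t = Casts.to_dim1(grad_output)  # (M, N)
            x_t = Casts.to_dim1(x)             # (K, N)
            grad_weight = go_t @ x_t.t()       # (M, K)
        return grad_input, grad_weight


def count_casts(weight_requires_grad):
    """Return (number of cast calls during fwd+bwd, weight.grad)."""
    x = torch.randn(4, 8, requires_grad=True)
    w = torch.randn(16, 8, requires_grad=weight_requires_grad)
    # wraps= keeps the real behavior while recording every call.
    with mock.patch.object(Casts, "to_dim1", wraps=Casts.to_dim1) as spy:
        CountedMM.apply(x, w).sum().backward()
    return spy.call_count, w.grad
```

Comparing `count_casts(False)` with `count_casts(True)` shows the frozen path making fewer cast calls, which is the property the real test asserts against the actual kernel wrapper.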