
fix(Int8DynActInt4WeightQuantizer): use scales_precision for runtime modules#4459

Open
Anai-Guo wants to merge 2 commits into pytorch:main from Anai-Guo:fix/int8dynact-int4-scales-precision

Conversation


@Anai-Guo Anai-Guo commented Jun 2, 2026

Contributor

Summary

Int8DynActInt4WeightQuantizer._convert_for_runtime passed self.precision as the scales_precision argument of replace_linear_8da4w, so the runtime Int8DynActInt4WeightLinear modules were built with the wrong scales dtype. The code even carried a # TODO: this should be self.scales_precision? next to the line.

This diverges from _create_quantized_state_dict, which uses self.scales_precision for the scales. As reported in #2571, that mismatch makes quantize(model).state_dict() inconsistent with _create_quantized_state_dict(model) whenever precision != scales_precision.

Fix

Pass self.scales_precision (instead of self.precision) as the scales_precision argument, matching _create_quantized_state_dict.

 replace_linear_8da4w(
     model,
     self.groupsize,
     self.padding_allowed,
     self.precision,
-    # TODO: this should be self.scales_precision?
-    self.precision,
+    self.scales_precision,
 )
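To make the failure mode concrete, here is a minimal pure-Python sketch of the call site. It assumes only the argument order quoted in this PR (scales_precision is the 5th positional argument). replace_linear_sketch, QuantizerSketch, the groupsize default, and the string dtype labels are hypothetical stand-ins, not torchao APIs. The sketch shows why the bug was silent: both slots take a dtype, so passing precision where scales_precision belongs type-checks fine and only changes the allocated scales dtype.

```python
from dataclasses import dataclass


def replace_linear_sketch(model, groupsize, padding_allowed, precision, scales_precision):
    # Hypothetical stand-in for replace_linear_8da4w: instead of swapping modules,
    # it reports the dtype each runtime buffer would be allocated with. As in the
    # real call site, scales_precision is the 5th positional argument.
    return {"bias": precision, "scales": scales_precision, "zeros": scales_precision}


@dataclass
class QuantizerSketch:
    # Hypothetical quantizer holding only the fields the call site reads.
    groupsize: int = 32
    padding_allowed: bool = True
    precision: str = "float32"
    scales_precision: str = "bfloat16"

    def convert_for_runtime_pre_fix(self, model=None):
        # Pre-fix: self.precision is passed twice, so it lands in the scales slot.
        return replace_linear_sketch(
            model, self.groupsize, self.padding_allowed, self.precision, self.precision
        )

    def convert_for_runtime_post_fix(self, model=None):
        # Post-fix: the scales slot receives self.scales_precision.
        return replace_linear_sketch(
            model, self.groupsize, self.padding_allowed, self.precision, self.scales_precision
        )


q = QuantizerSketch()
print(q.convert_for_runtime_pre_fix())   # {'bias': 'float32', 'scales': 'float32', 'zeros': 'float32'}
print(q.convert_for_runtime_post_fix())  # {'bias': 'float32', 'scales': 'bfloat16', 'zeros': 'bfloat16'}
```

When precision == scales_precision both methods return the same mapping, which is why the bug only surfaces when the two dtypes differ.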

Fixes #2571

Test plan

  • Instantiate Int8DynActInt4WeightQuantizer(precision=torch.float32, scales_precision=torch.bfloat16), quantize a small linear model, and confirm the runtime module scales dtype now matches _create_quantized_state_dict output.
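The check in the test plan comes down to comparing the dtype of each scales/zeros entry produced by the two code paths. Below is a pure-Python sketch of that comparison. The two dictionaries are hypothetical name-to-dtype maps standing in for _create_quantized_state_dict(model) and quantize(model).state_dict(), and dtype_mismatches is an illustrative helper, not part of torchao.

```python
def dtype_mismatches(reference, runtime, suffixes=("scales", "zeros")):
    """Return sorted keys whose dtype differs between two name -> dtype maps.

    Only keys ending in one of `suffixes` are compared, mirroring the test
    plan's focus on the scales (and zeros) tensors. A key missing from
    `runtime` also counts as a mismatch.
    """
    mismatched = []
    for name, ref_dtype in reference.items():
        if not name.endswith(suffixes):
            continue
        if runtime.get(name) != ref_dtype:
            mismatched.append(name)
    return sorted(mismatched)


# Hypothetical dtype maps for a single quantized linear layer "fc".
state_dict_path = {"fc.weight": "int8", "fc.scales": "bfloat16", "fc.zeros": "bfloat16"}
pre_fix_runtime = {"fc.weight": "int8", "fc.scales": "float32", "fc.zeros": "float32"}
post_fix_runtime = {"fc.weight": "int8", "fc.scales": "bfloat16", "fc.zeros": "bfloat16"}

print(dtype_mismatches(state_dict_path, pre_fix_runtime))   # ['fc.scales', 'fc.zeros']
print(dtype_mismatches(state_dict_path, post_fix_runtime))  # []
```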

🤖 Generated with Claude Code

…modules

_convert_for_runtime passed self.precision as the scales_precision argument
to replace_linear_8da4w, so the runtime Int8DynActInt4WeightLinear modules
were created with the wrong scales dtype. This diverged from
_create_quantized_state_dict, which uses self.scales_precision, producing an
inconsistency between quantize(model).state_dict() and
_create_quantized_state_dict(model). Pass self.scales_precision instead.

Fixes pytorch#2571

pytorch-bot Bot commented Jun 2, 2026


🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4459

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 2 Unclassified Failures

As of commit de16b05 with merge base 28e6aca:

NEW FAILURES - The following jobs have failed:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 2, 2026 (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
@jerryzh168 jerryzh168 requested a review from metascroy June 2, 2026 17:40
Asserts that Int8DynActInt4WeightQuantizer, given a scales_precision distinct
from precision, builds runtime Int8DynActInt4WeightLinear modules whose
scales/zeros buffers use scales_precision. The test fails on the pre-fix
behavior, which passed self.precision as the scales_precision argument of
replace_linear_8da4w.

Signed-off-by: Tai An <antai12232931@outlook.com>
@Anai-Guo Anai-Guo requested a review from vkuzo as a code owner June 11, 2026 10:21
@andrewor14

Contributor

@claude review


claude Bot commented Jun 11, 2026


Claude finished @andrewor14's task in 56s.


Review: fix(Int8DynActInt4WeightQuantizer): use scales_precision for runtime modules

I've reviewed the change against origin/main. The fix is correct and well-targeted.

The bug

_convert_for_runtime builds the runtime Int8DynActInt4WeightLinear modules via replace_linear_8da4w, whose 5th positional argument is scales_precision (linear_quant_modules.py:521-522). That value flows into the scales/zeros buffer dtypes (linear_quant_modules.py:443-456). Previously the quantizer passed self.precision there, so the runtime buffers were allocated with the activation precision instead of the scales precision.

This matters because of how quantize() works:

state_dict = self._create_quantized_state_dict(model)   # scales in self.scales_precision
model = self._convert_for_runtime(model)                # buffers were in self.precision (bug)
model.load_state_dict(state_dict, strict=False)          # copy_ casts to the buffer dtype

load_state_dict copies into the existing buffers, so the destination dtype wins. With precision=fp32 and scales_precision=bf16, the bf16 scales produced by _create_quantized_state_dict were silently up-cast to fp32 on load, which is exactly the inconsistency reported in #2571. The one-line change makes _convert_for_runtime consistent with _create_quantized_state_dict, and the stale # TODO: this should be self.scales_precision? is correctly removed.
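The "destination dtype wins" behavior is the crux of the bug. The sketch below models the three quantize() steps in plain Python rather than torch: a hypothetical Buffer keeps its own dtype label and only overwrites its contents on copy_, mirroring how Tensor.copy_ casts into the destination. Buffer and quantize_sketch are illustrative names; dtypes are string labels and no numeric rounding is simulated.

```python
from dataclasses import dataclass, field


@dataclass
class Buffer:
    # Hypothetical stand-in for a registered tensor buffer: a dtype label plus values.
    dtype: str
    values: list = field(default_factory=list)

    def copy_(self, src_values):
        # Like Tensor.copy_, overwrite the contents but keep this buffer's dtype.
        self.values = list(src_values)
        return self


def quantize_sketch(precision, scales_precision, fixed):
    # 1. _create_quantized_state_dict: scales are produced in scales_precision.
    sd_values, sd_dtype = [0.5, 0.25], scales_precision
    # 2. _convert_for_runtime: the pre-fix code allocated the buffer in precision.
    runtime = Buffer(scales_precision if fixed else precision)
    # 3. load_state_dict: copy into the existing buffer, whose dtype is kept.
    runtime.copy_(sd_values)
    return sd_dtype, runtime.dtype


print(quantize_sketch("float32", "bfloat16", fixed=False))  # ('bfloat16', 'float32')
print(quantize_sketch("float32", "bfloat16", fixed=True))   # ('bfloat16', 'bfloat16')
```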

Test

test_8da4w_quantizer_scales_precision is a good regression test: it picks precision != scales_precision (the only configuration that exposes the bug), asserts both scales.dtype and zeros.dtype are bf16, and runs a forward pass. It would fail on the pre-fix code. Nice coverage.

Minor (non-blocking) suggestions

  • Tighten the assertion to also cover precision. To fully document the contract, you could also assert that the non-scales paths still honor precision: for example, mod.precision == torch.float32 and mod.bias.dtype == torch.float32 (the bias buffer is allocated with precision at linear_quant_modules.py:459). This guards against a future "fix" that over-corrects by routing everything through scales_precision.
  • Imports inside the test. The local from ... import statements are fine and match the lazy-import style elsewhere in this file, so no change is needed. Note only that Int8DynActInt4WeightQuantizer and Int8DynActInt4WeightLinear may already be importable at module top, depending on the test file's existing imports.
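The first suggestion amounts to a two-part dtype contract on each runtime module. Here it is sketched in plain Python with a hypothetical RuntimeLinearSketch whose string attributes stand in for the precision attribute and the scales, zeros, and bias buffer dtypes discussed above; a real test would inspect the Int8DynActInt4WeightLinear buffers and compare against torch dtypes instead.

```python
from dataclasses import dataclass


@dataclass
class RuntimeLinearSketch:
    # Hypothetical dtype labels for one runtime module's attributes/buffers.
    precision: str
    scales: str
    zeros: str
    bias: str


def check_8da4w_dtypes(mod, precision, scales_precision):
    # Scales and zeros must follow scales_precision (the bug fixed in this PR) ...
    assert mod.scales == scales_precision, f"scales: {mod.scales}"
    assert mod.zeros == scales_precision, f"zeros: {mod.zeros}"
    # ... while the non-scales paths must still follow precision, guarding
    # against an over-correction that routes everything through scales_precision.
    assert mod.precision == precision, f"precision: {mod.precision}"
    assert mod.bias == precision, f"bias: {mod.bias}"


good = RuntimeLinearSketch("float32", "bfloat16", "bfloat16", "float32")
over_corrected = RuntimeLinearSketch("bfloat16", "bfloat16", "bfloat16", "bfloat16")

check_8da4w_dtypes(good, "float32", "bfloat16")  # passes silently
try:
    check_8da4w_dtypes(over_corrected, "float32", "bfloat16")
except AssertionError as e:
    print("caught:", e)  # caught: precision: bfloat16
```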

Neither is required to land. The core fix is minimal, correct, and matches the existing state-dict path.
• branch fix/int8dynact-int4-scales-precision

Development

Successfully merging this pull request may close these issues.

Int8DynActInt4WeightQuantizer quantize() not propagating the scales_precision argument

2 participants