Conversation

@andrewor14 andrewor14 commented Sep 11, 2025

**Summary:** Similar to #2937, this commit improves the prepare vs convert SQNR of int4 weight-only QAT from 12 to 45. This is achieved by mimicking the numerics of the target FBGEMM bf16-int4 kernel more closely. In particular, the FBGEMM kernel (sketched below):

1. Performs asymmetric [0, 15] quantization first, then recenters to 8
2. Uses a smaller scale eps of 1e-6 instead of bf16's eps (0.0078125)
3. Quantizes the weights using the min val instead of zero points
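
For illustration, here is a minimal sketch of those numerics in plain PyTorch. This is not the PR's actual code; the `scale` and `zero_point` formulas mirror the fake-quantizer snippet quoted in the review thread below, while `group_size` and the exact rounding/clamping are assumptions:

```python
import torch

def fbgemm_style_int4_fake_quant(w: torch.Tensor, group_size: int = 128,
                                 eps: float = 1e-6) -> torch.Tensor:
    qmax = 15      # (1) asymmetric [0, 15] quantization range...
    mid = 8        #     ...then recenter so quantized values land in [-8, 7]
    wg = w.to(torch.float32).view(w.shape[0], -1, group_size)
    max_val = torch.amax(wg, dim=-1, keepdim=True)
    min_val = torch.amin(wg, dim=-1, keepdim=True)
    # (2) small eps of 1e-6 instead of bf16's eps (0.0078125)
    scale = torch.clamp(max_val - min_val, min=eps) / qmax
    zero_point = min_val + scale * mid
    # (3) quantize relative to min_val rather than the zero point, then recenter
    q = torch.clamp(torch.round((wg - min_val) / scale), 0, qmax) - mid
    # dequantize for the fake-quant (prepare) path
    return (q * scale + zero_point).view_as(w)
```

Matching these details in the prepare (fake quant) path is what closes the gap with the convert (FBGEMM) path.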

**Unit tests:**

```
python test/quantization/test_qat.py -k test_quantize_api_int4
python test/quantization/test_qat.py -k test_fbgemm_int4_weight_only_primitives
```
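
For reference, the "prepare vs convert SQNR" above is the signal-to-quantization-noise ratio between the outputs of the prepared (fake-quantized) model and the converted (actually int4-quantized) model. A minimal sketch of how such a number is computed, not the repo's exact test helper:

```python
import torch

def sqnr_db(ref: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    # Signal-to-quantization-noise ratio in dB; higher means the two
    # outputs agree more closely (this PR raises it from ~12 to ~45).
    noise = ref - actual
    return 10 * torch.log10(ref.pow(2).sum() / noise.pow(2).sum())
```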

**End-to-end tests:**

Fine-tuning Llama3.1-8B with and without this PR in unsloth:

- fine-tune for 1 epoch on yahma/alpaca-cleaned with LoRA
- batch size 8, learning rate 2e-4, no gradient accumulation

Wikitext:

- The QAT int4 quantized model (with this PR) achieved 33% lower perplexity degradation, relative to the float model, than the int4 baseline
- The QAT int4 quantized model without this PR was worse than the int4 baseline

```
==> unsloth_model_lora_baseline_output/lm_eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |7.5551|±  |   N/A|

==> unsloth_model_lora_baseline_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |8.7655|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |8.3548|±  |   N/A|

# QAT without this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.0683|±  |   N/A|
```


pytorch-bot bot commented Sep 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2986

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 2bc59a1 with merge base 10ba659:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Sep 11, 2025
self._test_quantize_api_against_ptq(
Int4WeightOnlyConfig(version=version),
target_prepare_sqnr=12,
Int4WeightOnlyConfig(version=version, int4_packing_format=packing_format),
Contributor

I feel it's fine for QAT to only support version 2

Contributor

although you may want to cover more int4 packing formats, such as TILE_PACKED_TO_4D (the previous tinygemm layout)
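
As context for this suggestion, a hedged sketch of what exercising another packing format could look like. Only the `int4_packing_format=` argument and the `TILE_PACKED_TO_4D` name come from this thread; the import path is torchao's public API, and whether the argument takes a string or an `Int4PackingFormat` enum value is an assumption here:

```python
# Hypothetical sketch, not code from this PR.
from torchao.quantization import Int4WeightOnlyConfig

config = Int4WeightOnlyConfig(
    version=2,                                # per the comment above, QAT targets version 2
    int4_packing_format="tile_packed_to_4d",  # assumed spelling/type of TILE_PACKED_TO_4D
)
```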

Contributor Author

yeah I think we can drop version 1, but it's BC breaking so we can do it separately

Comment on lines +183 to +188
fbgemm_symmetric_qmax = 8
w_grouped = w.to(torch.float32).view(w.shape[0], -1, self.config.group_size)
max_val = torch.amax(w_grouped, dim=-1, keepdim=True)
min_val = torch.amin(w_grouped, dim=-1, keepdim=True)
scale = torch.clamp(max_val - min_val, min=eps) / qmax
zero_point = min_val + scale * fbgemm_symmetric_qmax
Contributor

@jerryzh168 jerryzh168 Sep 11, 2025

why don't we call int4_row_quantize_zp and get the scale/zero_point from there? is it because of performance concerns?

I guess we could ask fbgemm to add another function to just compute scale/zero_point so we can call it here in the future

Contributor Author

yeah and also they cast the quantized values to int8, which we don't want to do here
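
As a side note on the snippet above, the reason `zero_point = min_val + scale * 8` pairs with the [0, 15]-then-recenter scheme: for recentered values q in [-8, 7], dequantizing as `q * scale + zero_point` is identical to the usual asymmetric form `(q + 8) * scale + min_val`. A tiny illustrative check (not code from the PR):

```python
import torch

torch.manual_seed(0)
group_size, eps, qmax, mid = 8, 1e-6, 15, 8
w = torch.randn(4, 16)
wg = w.view(w.shape[0], -1, group_size)
max_val = torch.amax(wg, dim=-1, keepdim=True)
min_val = torch.amin(wg, dim=-1, keepdim=True)
scale = torch.clamp(max_val - min_val, min=eps) / qmax
zero_point = min_val + scale * mid
q = torch.clamp(torch.round((wg - min_val) / scale), 0, qmax) - mid  # in [-8, 7]

# Recentered dequant and plain asymmetric dequant agree
assert torch.allclose(q * scale + zero_point, (q + mid) * scale + min_val)
```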

@andrewor14 andrewor14 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label Sep 12, 2025
@andrewor14
Contributor Author

@jerryzh168 I updated the PR description with end-to-end tasks, can you take another look?

@andrewor14 andrewor14 merged commit ea8c00f into main Sep 15, 2025
18 checks passed
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
Details TBD.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
Details TBD.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: 04f6bce
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
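A minimal sketch of point 3 in the commit message above (fake-rounding blockwise scales to float8). This is an illustration of the idea using torch's `float8_e4m3fn` dtype, not the commit's actual code:

```python
import torch

def fake_round_scale_to_float8(scale: torch.Tensor) -> torch.Tensor:
    # Round-trip the blockwise scales through float8 so the QAT (prepare)
    # path sees the same scale quantization the PTQ (convert) path will use,
    # while keeping the tensor in float32 for the rest of the fake-quant math.
    return scale.to(torch.float8_e4m3fn).to(torch.float32)
```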
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: 04f6bce
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: 04f6bce
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 23, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: 47019f4
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 24, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 24, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 24, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: d8f7eff
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 24, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 24, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 24, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: d0120f0
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: cb5a5e1
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: cb5a5e1
Pull Request resolved: #3050
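A hedged sketch of the "new linear class" idea described in the commit message above: a QAT linear that owns both the fake quantization and the matmul in one forward, so the prepare path can mirror the convert (PTQ) numerics exactly. The class name and the `_fake_quant_nvfp4` placeholder below are illustrative, not the names added by the commit:

```python
import torch
import torch.nn.functional as F

def _fake_quant_nvfp4(t: torch.Tensor) -> torch.Tensor:
    # Placeholder for the NVFP4 fake-quantization primitives (blockwise fp8
    # scales, fp4 rounding); identity here just to keep the sketch runnable.
    return t

class FakeQuantizedNVFP4Linear(torch.nn.Linear):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w_fq = _fake_quant_nvfp4(self.weight)  # fake-quantize weights
        x_fq = _fake_quant_nvfp4(x)            # and activations, in one place
        return F.linear(x_fq, w_fq, self.bias)
```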
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: 25b4383
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 25, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: ecbff90
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: ecbff90
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

ghstack-source-id: a707a59
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: a707a59
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: 633bc65
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: 633bc65
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: 633bc65
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 26, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: 77f47b7
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 30, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: 77f47b7
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 30, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: 77f47b7
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 30, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 30, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]
andrewor14 added a commit that referenced this pull request Sep 30, 2025
**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

ghstack-source-id: bb1356c
Pull Request resolved: #3050
andrewor14 added a commit that referenced this pull request Sep 30, 2025
* Support NVFP4 dynamic per tensor scale

**Summary:** This commit adds an option for the existing
`NVFP4InferenceConfig` to dynamically compute an appropriate
fp32 per tensor scale to support the two level scaling
according to the NVFP4 specification:
https://developer.nvidia.com/blog/introducing-nvfp4-for-efficient-and-accurate-low-precision-inference/.

While two level scaling is supported in `NVFP4Tensor`, today
there is no config API for users to call this. The existing
`NVFP4InferenceConfig` only supports single level scaling
because including an explicit `per_tensor_scale` field would
make serialization tricky.

In the future, we should add an end-to-end calibration flow
so users can compute an appropriate per tensor scale for the
activations first, and then pass this to `NVFP4Tensor` as a
static scale, similar to the proposal in #2572.

**Test Plan:**
```
pytest test/prototype/mx_formats/test_inference_workflow.py -k test_inference_workflow_nvfp4
pytest test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

Also did a quick benchmark before and after:
```
import copy
import time
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import NVFP4InferenceConfig

m_mx1 = torch.nn.Linear(64, 256, bias=True, dtype=torch.bfloat16, device="cuda")
m_mx2 = copy.deepcopy(m_mx1)
config1 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=False)
config2 = NVFP4InferenceConfig(use_dynamic_per_tensor_scale=True)
quantize_(m_mx1, config=config1)
quantize_(m_mx2, config=config2)
m_mx1 = torch.compile(m_mx1, fullgraph=True, backend="aot_eager")
m_mx2 = torch.compile(m_mx2, fullgraph=True, backend="aot_eager")

start = time.time()
for _ in range(1000):
    m_mx1(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("No per_tensor_scale = ", time.time() - start, "seconds")

start = time.time()
for _ in range(1000):
    m_mx2(torch.randn(128, 64, device="cuda", dtype=torch.bfloat16))
print("With per_tensor_scale = ", time.time() - start, "seconds")
```

On a single B200:
```
No per_tensor_scale =  1.2855589389801025 seconds
With per_tensor_scale =  1.3009123802185059 seconds
```

[ghstack-poisoned]

* Improve QAT nvfp4 numerics

**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
Details TBD.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`,
   but in `torch.int32` instead of `torch.uint8`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to 36 with `use_per_tensor_scale`, and 12 to inf without.
This is achieved by mimicking the PTQ flow more closely,
in particular, in descending order of significance:

1. Simulate `f4_unpacked_to_f32` and `f32_to_f4_unpacked`
2. Do not cast intermediate fake quantized values to original
   dtype, e.g. bf16 which loses some fidelity from fp32
3. Fake round blockwise scales to float8

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Test Plan:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

End-to-end tests TBD.

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986, this commit improves the prepare vs convert SQNR of NVFP4 QAT from 12 to inf. This is achieved by refactoring NVFP4 QAT to mimic the PTQ numerics exactly, using a new linear class to incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> unsloth_model_lora_qat_int4_output/lm_eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]

* Update base for Update on "Improve QAT nvfp4 numerics"


**Summary:** Similar to #2986,
this commit improves the prepare vs convert SQNR of NVFP4 QAT
from 12 to inf. This is achieved by refactoring NVFP4 QAT to
mimic the PTQ numerics exactly, using a new linear class to
incorporate both the quantization and mm logic.

**Unit tests:**
```
python test/quantization/test_qat.py -k test_qat_nvfp4
python test/quantization/test_qat.py -k test_quantize_api_nvfp4
```

**End-to-end tests:**
Fine-tuning Llama3.2-3B with and without this PR in axolotl:
- fine-tune for 1 epoch on yahma/alpaca-cleaned
- batch size 512, learning rate 2e-5, no gradient accumulation

Wikitext:

- With this PR, QAT nvfp4 quantized model achieved 15% lower
  perplexity than the quantized baseline
- Without this PR, QAT nvfp4 quantized model was about the
  same as the quantized baseline

```
==> Llama3.2-3B_baseline_bs512/eval_float.log <==
|        |       |none  |     0|word_perplexity|↓  |9.418|±  |   N/A|

==> Llama3.2-3B_baseline_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.3681|±  |   N/A|

# QAT with this PR (quantized)
==> Llama3.2-3B_qat_bs512/eval_quantized.log <==
|        |       |none  |     0|word_perplexity|↓  |10.2281|±  |   N/A|
```

[ghstack-poisoned]