Refactor int4 and int8 weight only quantization to use `quantize` #301
…antize`
Summary:
Previously we added `quantize` as a general API (pytorch#256) for the Affine Quantized tensor subclass, and for tensor-subclass-based dtype conversion in general. The plan is to use it to replace the existing quant APIs, including int4 weight only, int8 weight only, int8 dynamic quant and 8da4w (for executorch). In this PR we started replacing the implementation of the int8 dynamic quant API with the `quantize` API and the affine quantized tensor subclass. We'll make sure the performance does not regress for the vit model.
Test Plan:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
reference: elapsed_time: 1.4821058654785155 milliseconds
after refactor: elapsed_time: 1.4804757690429688 milliseconds
generated code diff: https://gist.github.com/jerryzh168/90c71107a5aaaa5d8dd2170c573e076d
Reviewers:
Subscribers:
Tasks:
Tags:
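The refactors in this PR all route through a `quantize(model, apply_fn, filter_fn)` call. Below is a minimal pure-Python sketch of that shape, assuming the function walks the module tree, asks `filter_fn(module, fqn)` whether each submodule should be converted, and replaces that submodule's weight with `apply_fn(weight)`. `Module`, `Linear`, `is_linear` and `quantize_sketch` are hypothetical stand-ins so the example runs without PyTorch; this is not torchao's implementation.

```python
from typing import Callable, Dict, List, Optional


# Hypothetical stand-ins for torch.nn.Module / torch.nn.Linear so the sketch
# runs without PyTorch; weights are plain nested lists here.
class Module:
    def __init__(self, **children: "Module") -> None:
        self.children: Dict[str, "Module"] = dict(children)


class Linear(Module):
    def __init__(self, weight: List[List[float]]) -> None:
        super().__init__()
        self.weight = weight


def is_linear(module: Module, fqn: str) -> bool:
    """Default filter: select every Linear, regardless of its name."""
    return isinstance(module, Linear)


def quantize_sketch(
    model: Module,
    apply_fn: Callable[[List[List[float]]], object],
    filter_fn: Optional[Callable[[Module, str], bool]] = None,
    _prefix: str = "",
) -> Module:
    """Recursively replace the weight of every submodule that passes filter_fn."""
    filter_fn = filter_fn or is_linear
    for name, child in model.children.items():
        fqn = f"{_prefix}.{name}" if _prefix else name  # fully qualified name
        if filter_fn(child, fqn):
            # Stand-in for swapping in a quantized tensor subclass.
            child.weight = apply_fn(child.weight)
        quantize_sketch(child, apply_fn, filter_fn, fqn)
    return model
```

With no `filter_fn`, every `Linear` is converted; a `filter_fn` that also inspects `fqn` lets a caller skip, say, a classification head.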
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/301
Note: Links to docs will display an error until the docs builds have been completed.
✅ You can merge normally! (1 Unrelated Failure)
As of commit 74ecb09 with merge base 729fa4d:
BROKEN TRUNK - The following job failed but was also present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This is rebased on the int8-wo PR (#299), so I will need to update this PR after the int8-wo PR lands.
@parameterized.expand(COMMON_DEVICE_DTYPE)
@unittest.skipIf(TORCH_VERSION_AFTER_2_4, "skip because there is some bug in inductor codegen")
Is there an issue or a short description of both bugs we can add? Otherwise it will be hard to remember when to remove the skipIf.
I think it's just an inductor C++ compilation bug. I'm planning to open a PR after this; I have opened one for the other skip here: #300
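One way to address the review comment is to put the tracking reference into the skip reason itself, so it appears in test output next to every skip. A minimal sketch with the standard `unittest` module; `TORCH_VERSION_AFTER_2_4` is hard-coded here, and the issue URL is a hypothetical placeholder, since no issue number for this bug appears in the thread.

```python
import unittest

# Hypothetical: in the real test file this flag comes from torchao's version checks.
TORCH_VERSION_AFTER_2_4 = True

# Hypothetical placeholder: fill in once the tracking issue has been filed.
INDUCTOR_CPP_CODEGEN_BUG = "https://github.com/pytorch/ao/issues/<tracking-issue>"


class TestInt4WeightOnly(unittest.TestCase):
    @unittest.skipIf(
        TORCH_VERSION_AFTER_2_4,
        f"inductor C++ codegen bug; remove this skip once {INDUCTOR_CPP_CODEGEN_BUG} is closed",
    )
    def test_int4_wo_quant(self) -> None:
        self.assertTrue(True)  # real test body elided
```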
return decorator
def get_aqt_layout_cls(extended_layout: str) -> Callable:
def get_aqt_layout_cls_ctr(extended_layout: str) -> Callable:
`ctr` means constructor, since we are returning `class.from_plain` now.
This needs a comment; I don't believe `ctr` is a common abbreviation for constructor.
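The function discussed above now returns a constructor rather than a class. A minimal sketch of that pattern, assuming a decorator-based registry keyed by `extended_layout` (the `return decorator` context line hints at one, but its details are not shown). `_LAYOUT_CLS`, `register_layout_cls`, `LayoutBase`, `PlainLayout` and `get_layout_constructor` are hypothetical names. Spelling out `constructor` in the name, plus a docstring, is one way to address the concern that `ctr` is not a common abbreviation.

```python
from typing import Callable, Dict, Type

# Hypothetical registry mapping an extended_layout string to a layout class.
_LAYOUT_CLS: Dict[str, Type["LayoutBase"]] = {}


def register_layout_cls(extended_layout: str) -> Callable[[type], type]:
    def decorator(cls: type) -> type:
        _LAYOUT_CLS[extended_layout] = cls
        return cls
    return decorator


class LayoutBase:
    def __init__(self, int_data, scale, zero_point) -> None:
        self.int_data, self.scale, self.zero_point = int_data, scale, zero_point

    @classmethod
    def from_plain(cls, int_data, scale, zero_point) -> "LayoutBase":
        # A packed layout would repack int_data here; the plain one stores it as-is.
        return cls(int_data, scale, zero_point)


@register_layout_cls("plain")
class PlainLayout(LayoutBase):
    pass


def get_layout_constructor(extended_layout: str) -> Callable[..., LayoutBase]:
    """Return the constructor (cls.from_plain) registered for extended_layout."""
    if extended_layout not in _LAYOUT_CLS:
        raise ValueError(f"unknown extended_layout: {extended_layout!r}")
    return _LAYOUT_CLS[extended_layout].from_plain
```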
# int_data = int_data.view(shape)
# changed = self.from_plain(int_data, scale, zero)
# return changed
# TODO: changing shape is no-op for int4 packed weight right now
Could you share some more detail on this? I'm quite curious.
Yeah, I'm confirming with @HDCharles right now. I think this is pretty weird; see the comments at L575 of aqt.py for more details.
@classmethod
def from_plain(cls, int_data, scale, zero_point):
    # TODO: expose the arg
This one needs a bit more discussion with the PyTorch core team.
if extended_layout == "tensor_core_tiled":
    from torchao.quantization.utils import find_multiple
    orig_out_features, orig_in_features = input_float.shape
    in_features = find_multiple(orig_in_features, 1024)
Where do the constants 1024 and 8 come from?
I think this is specific to the tinygemm kernels; it was copied from the old code:
ao/torchao/quantization/subclass.py
Lines 585 to 586 in 8a4e693
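A small sketch of the resulting shape arithmetic, assuming the weight is zero-padded so that `in_features` becomes a multiple of 1024 (as in the snippet) and `out_features` a multiple of 8 (that line is not shown in the snippet; it is an assumption based on the reviewer's question). `round_up` and `tinygemm_padded_shape` are hypothetical helpers, and why tinygemm needs these particular multiples is not stated in the thread.

```python
from typing import Tuple


def round_up(n: int, multiple: int) -> int:
    """Smallest value >= n that is divisible by multiple."""
    return ((n + multiple - 1) // multiple) * multiple


def tinygemm_padded_shape(
    orig_out_features: int,
    orig_in_features: int,
    in_multiple: int = 1024,  # constant from the snippet above
    out_multiple: int = 8,    # assumed to apply to out_features
) -> Tuple[int, int]:
    """Shape an (out_features, in_features) weight would be zero-padded to."""
    return (
        round_up(orig_out_features, out_multiple),
        round_up(orig_in_features, in_multiple),
    )
```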
torchao.apply_dynamic_quant(model)
from torch._inductor import config as inductorconfig
inductorconfig.force_fuse_int_mm_with_mul = True
# int8 act, int8 weight dynamic quantization
Should we delete the code here instead of commenting it out?
Sure. This is just so people can easily try out different APIs, but we can also ask people to copy-paste from the README.
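The `force_fuse_int_mm_with_mul` flag in the snippet, as its name suggests, concerns an integer matmul followed by a pointwise multiply. A pure-Python sketch of the unfused math of an int8-activation, int8-weight linear: accumulate integer products, then rescale by a per-row activation scale and a per-output-channel weight scale. It assumes symmetric quantization (no zero points); `int_mm` and `int8_dynamic_linear` are hypothetical names, not the inductor kernels.

```python
from typing import List

Matrix = List[List[int]]


def int_mm(a: Matrix, b_t: Matrix) -> Matrix:
    """a: (M, K) int8 values, b_t: (N, K) int8 weight rows -> (M, N) integer sums."""
    return [[sum(x * w for x, w in zip(row, col)) for col in b_t] for row in a]


def int8_dynamic_linear(
    x_int8: Matrix,
    x_scales: List[float],  # one scale per activation row (per token)
    w_int8: Matrix,         # (out_features, in_features)
    w_scales: List[float],  # one scale per output channel
) -> List[List[float]]:
    acc = int_mm(x_int8, w_int8)  # integer matmul, exact integer accumulation
    # The pointwise rescale that follows; this is the mul a fused kernel would fold in.
    return [
        [a * xs * ws for a, ws in zip(row, w_scales)]
        for row, xs in zip(acc, x_scales)
    ]
```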
# groupwise int4 quantization
groupsize = weight_qtensor.block_size[-1]
if not _from_flinear:
    weight_qtensor = weight_qtensor.t()
n00b q: why does this require a transpose?
This is to align the dimensions for block_size, so that we can get groupsize from the block_size argument (see L662; this is also related to L575). Right now _quantized_linear does not have a well-defined accepted weight shape; we need to fix that.
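To make the transpose discussion concrete, here is a pure-Python sketch of groupwise affine int4 quantization with a `block_size` of `(1, groupsize)` over an `(out_features, in_features)` weight: each row is cut into groups along the last dimension, and each group gets its own scale and zero point. Reading `groupsize` as `block_size[-1]` only works because the groups lie on the last dimension; for a weight stored transposed, the grouped dimension would come first, which is consistent with the explanation above. This is a generic min/max asymmetric scheme with unsigned 4-bit codes, assumed for illustration; it is not the exact tinygemm numerics, and `groupwise_affine_quantize` is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def groupwise_affine_quantize(
    weight: List[List[float]],  # (out_features, in_features)
    block_size: Sequence[int],  # (1, groupsize): groups run along the last dim
    n_bits: int = 4,
) -> Tuple[List[List[int]], List[List[float]], List[List[int]]]:
    groupsize = block_size[-1]  # valid only because groups lie on the last dim
    assert block_size[0] == 1 and len(weight[0]) % groupsize == 0
    qmax = 2 ** n_bits - 1  # unsigned code range [0, 15] for 4 bits
    q, scales, zeros = [], [], []
    for row in weight:
        q_row, s_row, z_row = [], [], []
        for start in range(0, len(row), groupsize):
            group = row[start:start + groupsize]
            # Include 0 in the range so that 0.0 is exactly representable.
            lo, hi = min(min(group), 0.0), max(max(group), 0.0)
            scale = (hi - lo) / qmax or 1.0  # avoid dividing by zero for all-zero groups
            zero = round(-lo / scale)
            q_row += [min(max(round(v / scale) + zero, 0), qmax) for v in group]
            s_row.append(scale)
            z_row.append(zero)
        q.append(q_row)
        scales.append(s_row)
        zeros.append(z_row)
    return q, scales, zeros
```

Dequantizing a code is then `(q - zero) * scale` using the scale and zero point of its group.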
def _quantized_linear_op(input_tensor, weight_qtensor, bias):
def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True):
    # TODO: the old tensor subclass can use the single implementation for both F.linear dispatch
@msaroufim, see this comment for more details.
    filter_fn,
)
if TORCH_VERSION_AFTER_2_4:
    quantize(model, get_apply_int4wo_quant(**kwargs), filter_fn)
Blind kwargs make it impossible to document the behavior. I understand that change_linear_weights_to_int4_woqtensors has this as well, but it seems like something worth fixing.
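A sketch of what replacing blind `**kwargs` with explicit, documented parameters could look like. `Int4WeightOnlyConfig` is hypothetical, and the field names `groupsize` and `inner_k_tiles` as well as their allowed values are assumptions about the int4 weight-only path, not taken from this PR.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Int4WeightOnlyConfig:
    """Hypothetical explicit replacement for **kwargs.

    groupsize: number of input-channel elements that share one scale/zero point.
    inner_k_tiles: packing parameter for the tiled int4 layout.
    """

    groupsize: int = 128
    inner_k_tiles: int = 8

    def __post_init__(self) -> None:
        # Allowed values are assumptions for this sketch.
        if self.groupsize not in (32, 64, 128, 256):
            raise ValueError(f"unsupported groupsize: {self.groupsize}")
        if self.inner_k_tiles not in (2, 4, 8):
            raise ValueError(f"unsupported inner_k_tiles: {self.inner_k_tiles}")
```

A misspelled keyword such as `Int4WeightOnlyConfig(group_size=64)` then fails immediately with a `TypeError`, and each option has one documented place to live.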
return measurement.mean * 1e6

def find_multiple(n: int, *args: Tuple[int]) -> int:
We now use this in torchao/dtypes and torchao/quantization, and have to do import tricks to avoid a circular dependency.
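Only the signature of `find_multiple` appears in the diff. A minimal sketch, assuming it rounds `n` up to the nearest multiple of the least common multiple of `args`. An annotation on `*args` types each element, so `*args: int` is the accurate annotation for calls like `find_multiple(orig_in_features, 1024)`. Because a helper like this has no dependencies, moving it into a small standalone module that both torchao/dtypes and torchao/quantization import is one possible way to avoid the import tricks (a suggestion, not something the thread settled on).

```python
from functools import reduce
from math import gcd


def find_multiple(n: int, *args: int) -> int:
    """Round n up to the nearest multiple of lcm(*args); with no args the lcm is 1."""
    k = reduce(lambda x, y: x * y // gcd(x, y), args, 1)
    return n if n % k == 0 else n + k - (n % k)
```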
Please don't merge PRs when CI is red; we can't get signal for incremental changes. Fix main CI first, then merge.
Makes sense, sorry about this. Will do next time.
Summary:
This is similar to #294 but applied to int4 weight only quantization
Test Plan:
unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int4_wo_quant_perf
elapsed time: 0.2166275215148926, ref elapsed time: 0.2191881561279297
elapsed time: 0.2376406478881836, ref elapsed time: 0.22721023559570314
elapsed time: 0.21919679641723633, ref elapsed time: 0.2154969596862793
integration perf test:
reference: elapsed_time: 2.5900126953125 milliseconds
after refactor: elapsed_time: 2.56680078125 milliseconds
diff: no diff
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
Before:
After:
generated code diff:
Reviewers:
Subscribers:
Tasks:
Tags:
Refactor int8 weight only quant to use `quantize` (#299)
Summary:
Similar to #294, we replaced the implementation
of int8 weight only quant to use the newly added `quantize` function, as part of
the unification effort for affine quantization.
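For reference, a generic sketch of int8 weight-only quantization in its common symmetric per-channel form: one scale per output channel (row), zero point fixed at 0, values clamped to [-128, 127]. The granularity, the `eps` floor and the name `int8_per_channel_symmetric` are assumptions for illustration; the summary does not spell out the exact scheme.

```python
from typing import List, Tuple


def int8_per_channel_symmetric(
    weight: List[List[float]], eps: float = 1e-12
) -> Tuple[List[List[int]], List[float]]:
    """Quantize each output channel (row) with its own scale; zero point is 0."""
    q, scales = [], []
    for row in weight:
        # Map the largest magnitude in the row to 127; eps guards all-zero rows.
        scale = max(max(abs(v) for v in row) / 127.0, eps)
        scales.append(scale)
        q.append([min(max(round(v / scale), -128), 127) for v in row])
    return q, scales
```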
Test Plan:
unit perf test:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_wo_quant_perf
elapsed time: 0.23909856796264647, ref elapsed time: 0.25150911331176756
elapsed time: 0.24894208908081056, ref elapsed time: 0.2570047950744629
elapsed time: 0.21607391357421876, ref elapsed time: 0.22809568405151368
integration test:
TORCH_LOGS='output_code' python tutorials/quantize_vit/run_vit_b_quant.py
Reference: elapsed_time: 1.355208740234375 milliseconds
After refactor: elapsed_time: 1.32778857421875 milliseconds
code diff (gist): gist.github.com/jerryzh168/921a722cf20d476c8fc5888482e722dc
code diff (meta-only paste): internalfb.com/phabricator/paste/view/P1387333845
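The perf numbers in these test plans are easier to compare as relative changes. A small helper applied to the int8 weight-only numbers reported above; `relative_change` is a hypothetical name, and the sign convention (negative means the refactor is faster) is a choice.

```python
def relative_change(ref: float, new: float) -> float:
    """Fractional change of new vs. ref; negative means new is faster."""
    return (new - ref) / ref


# int8 weight-only ViT integration numbers from the test plan above (milliseconds).
REF_MS = 1.355208740234375
NEW_MS = 1.32778857421875

# (elapsed, ref elapsed) pairs from the unit perf test above.
UNIT_RUNS = [
    (0.23909856796264647, 0.25150911331176756),
    (0.24894208908081056, 0.2570047950744629),
    (0.21607391357421876, 0.22809568405151368),
]

if __name__ == "__main__":
    print(f"integration: {relative_change(REF_MS, NEW_MS):+.1%}")
    for new, ref in UNIT_RUNS:
        print(f"unit: {relative_change(ref, new):+.1%}")
```

For the integration run this comes to about -2.0%; the int4 ViT numbers earlier in the thread (2.5900126953125 to 2.56680078125 ms) give roughly -0.9%.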