
RFC: Integration of KleidiAI 4-Bit MatMul Kernels into PyTorch #137830

Open
ng-05 opened this issue Oct 10, 2024 · 10 comments
Labels
matrix multiplication · oncall: quantization (Quantization support in PyTorch) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

ng-05 (Collaborator) commented Oct 10, 2024

RFC: Integration of KleidiAI 4-Bit MatMul Kernels into PyTorch

Hardware Platform: aarch64

Motivation

  1. Add support for KleidiAI 4-bit matrix multiplication kernels to PyTorch.
  2. Expose these 4-bit matmul kernels as PyTorch operators to accelerate ML models.

PyTorch Integration PR: #134124
torch/ao Quantization Change: ng-05/ao@cbcf915

Overview of 4-Bit MatMul Operations

The 4-bit matmul process consists of two main operations:

  1. Packing: Quantized weights (4-bit), floating-point scales and biases are combined into a single packed_weight buffer.
  2. Execution: Perform matrix multiplication on FP32 input using the packed_weight buffer.

Current Target Operations

We aim to enhance two existing PyTorch operations:

  1. _convert_weight_to_int4pack_cpu: This operation is responsible for converting weights to a packed format.
  2. _weight_int4pack_mm_cpu: This operation performs the matrix multiplication using the packed weights.

Issues Identified

  1. Modification of Operation Signatures:

    • For _convert_weight_to_int4pack_cpu:
      • Reason: The KleidiAI packing kernel requires scales, bias, and 4-bit quantized weights to create a single linear buffer.
      • Change Link
    • For _weight_int4pack_mm_cpu:
      • Reason: The packing operation loses the shape information of the original weights (N, K). The packed_weight buffer is one-dimensional and accounts for scales, bias, and weights. We need the row count (N) to perform the matmul correctly (see the sketch after this list).
      • Change Link
  2. Support for Channelwise and Groupwise Quantization:

    • Need: We require a way to pass group-size information to the underlying kernel for channelwise and groupwise quantized matmuls. Currently, a group size of 0 indicates channelwise quantization and a group size of 32 indicates groupwise quantization with groups of 32.
  3. Data Type Handling in Operations:

    • Issue: The existing _weight_int4pack_mm_cpu operator performs multiplication and accumulation in FP32/BF16. In contrast, our kernels dynamically quantize the FP32 input to INT8, perform the multiplication in INT8, and accumulate the results in FP32. This may introduce noticeable (but within an acceptable error range [mean error: 0.0064]) accuracy differences for the same operation across platforms.
    • Proposed Solution: Introduce a context manager to handle low-bit operations explicitly within the same PyTorch operator, conveying the datatype difference to the user. Something similar to, and as simple as, torch._C._set_cpu_allow_fp16_reduced_precision_reduction would suffice.
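
To make the requested changes concrete, here is a rough, hypothetical sketch of what the adjusted interfaces could look like. The function names, argument order, and reference bodies are illustrative assumptions (a real implementation would call the KleidiAI packing and matmul kernels), and only the channelwise case is handled for brevity:

import torch

def convert_weight_to_int4pack(w_int4: torch.Tensor,   # [N, K] INT4 values in an int8 container
                               scales: torch.Tensor,   # [N] channelwise scales
                               bias: torch.Tensor,     # [N]
                               group_size: int) -> torch.Tensor:
    # KleidiAI would emit an opaque 1-D packed buffer; this reference simply
    # concatenates weights, scales, and bias so the mm below can recover them.
    # group_size is carried for the groupwise case; ignored in this channelwise-only sketch.
    return torch.cat([w_int4.float().flatten(), scales.flatten(), bias.flatten()])

def weight_int4pack_mm(x: torch.Tensor, packed: torch.Tensor,
                       group_size: int, n: int, k: int) -> torch.Tensor:
    # N (and K) must be passed explicitly because the 1-D packed buffer no
    # longer carries the original weight shape. group_size == 0 means channelwise.
    w = packed[: n * k].reshape(n, k)
    scales = packed[n * k : n * k + n]
    bias = packed[-n:]
    return x @ (w * scales.unsqueeze(1)).t() + bias

# Example usage with toy shapes:
w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)
packed = convert_weight_to_int4pack(w, torch.full((4,), 0.1), torch.zeros(4), group_size=0)
y = weight_int4pack_mm(torch.randn(2, 8), packed, group_size=0, n=4, k=8)   # [2, 4] FP32 output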

Initial Approach:

  • We initially introduced new PyTorch operators due to the above-mentioned issues and the lack of support for 4-bit matmul operations using 8-bit intrinsics. However, this approach was quickly flagged by the Meta PyTorch team.

Why KleidiAI in PyTorch?

  • KleidiAI will expand beyond low-bit matmul kernels; we would like to add support for KleidiAI kernels directly in existing PyTorch operators so that the optimizations can be leveraged directly by ML models.

Proposed Path Forward

To resolve the above issues and successfully integrate our kernels into _convert_weight_to_int4pack_cpu and _weight_int4pack_mm_cpu, we seek suggestions and collaboration on the following:

  • Adjusting the operation signatures to accommodate the necessary parameters.
  • Finding a method to pass group size information for quantization.
  • Ensuring proper handling of data types during the matmul process.
  • Considering the implications of the context manager approach.

E2E Flow

Our torch/ao quantizer implementation can directly replace the GPTFast block in the diagram below.

[Diagram: Kleidiai_Pytorch_flow_2]

Existing 4-bit matmul kernel functionality in the _weight_int4pack_mm_cpu operation:

FP32 Input                INT4 Weights     FP32 Scales
[32.1, 15.7, -3.9]        [3, -2, 1]     [0.3, 0.2, 0.15]
       |                      |            / 
       |                      |           / 
       |                Dequantization   / 
       |                      |         /
       |                      v        /
       |             FP32 Dequantized Weights
       |            [0.375, -0.25, 0.125]
       |                      |
       |                      |
       +----------------------+
                  |
                  |
             Matrix Multiplication
                  |
                  v
         FP32 Output
    [12.09375, -3.9375, -0.48125]
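
For reference, a minimal runnable sketch of the semantics in the diagram above (not the real kernel; per-output-channel scales are used for brevity, whereas the actual op uses groupwise scales along K):

import torch

def int4_dequant_matmul(x_fp32, w_int4, w_scales):
    # Dequantize the INT4 weights to FP32, then multiply and accumulate in FP32.
    w_fp32 = w_int4.to(torch.float32) * w_scales.unsqueeze(1)
    return x_fp32 @ w_fp32.t()

x = torch.randn(2, 8)                               # FP32 activations [M, K]
w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # INT4 values in an int8 container [N, K]
out = int4_dequant_matmul(x, w, torch.full((4,), 0.1))  # FP32 output [M, N]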

KleidiAI 4-bit matmul kernel functionality in the _weight_int4pack_mm_cpu operation:

FP32 Input                Packed Buffer ( INT4 Weights + scales + Bias )
[32.1, 15.7, -3.9]        [3, -2, 1, 0]    [8.25, 3.5]  [1.2, 5.5]
       |                       |
       |                       |
    Quantize              Dequantize
   to INT8                 to INT8
       |                       |
       v                       v
INT8 Input              INT8 Weights
 [64, 31, -8]           [24, -16, 8, 0]
       \                     /
        \                   /
         \                 /
          \               /
           \             /
            \           /
             \         /
           INT8 Matrix Multiplication
                    |
                    v
   Dequantize and Accumulate in FP32
          [12.0, -3.875, 1.5, 0]
                    |
                    v
             FP32 Final Output
            [9.625, -3.875, 1.5]
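
For comparison, a rough numeric sketch of this dynamic-quantization path (the function name and the per-row activation scaling below are illustrative assumptions, not the exact KleidiAI scheme): the FP32 activation is quantized to INT8 on the fly, the INT4 weights are widened to INT8, the multiply-accumulate happens in the integer domain, and the result is rescaled to FP32 and biased.

import torch

def dynamic_int8_int4_matmul(x_fp32, w_int4, w_scales, bias):
    # Per-row symmetric quantization of the FP32 activation to INT8.
    x_scale = x_fp32.abs().amax(dim=1, keepdim=True) / 127.0
    x_int8 = torch.clamp(torch.round(x_fp32 / x_scale), -128, 127).to(torch.int8)

    # INT4 weight values are simply widened (sign-extended) to INT8.
    w_int8 = w_int4.to(torch.int8)

    # Integer matmul with a wide accumulator, then rescale back to FP32 and add bias.
    acc = x_int8.to(torch.int32) @ w_int8.to(torch.int32).t()
    return acc.to(torch.float32) * x_scale * w_scales + bias

x = torch.randn(2, 8)                               # FP32 activations [M, K]
w = torch.randint(-8, 8, (4, 8), dtype=torch.int8)  # INT4 values in an int8 container [N, K]
out = dynamic_int8_int4_matmul(x, w, torch.full((4,), 0.1), torch.zeros(4))  # FP32 [M, N]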

Your feedback and suggestions are highly welcome!

cc: @malfet @digantdesai @jgong5 @sanchitintel @cfRod @milpuz01

cc @jerryzh168 @jianyuh @raghuramank100 @jamesr66a @vkuzo @jgong5 @Xia-Weiwen @leslie-fang-intel @msaroufim

ezyang transferred this issue from pytorch/pytorch Oct 11, 2024
cfRod (Collaborator) commented Oct 11, 2024

@ezyang could you let us know why the RFC has been moved to torchao?

ezyang transferred this issue from pytorch/ao Oct 12, 2024
ezyang added the oncall: quantization, triaged, and matrix multiplication labels Oct 12, 2024
ezyang (Contributor) commented Oct 12, 2024

@jerryzh168 @vkuzo who should answer on this

milpuz01 (Contributor) commented:

cc @malfet @digantdesai

jerryzh168 (Contributor) commented:

We initially introduced new PyTorch operators due to the above-mentioned issues and the lack of support for 4-bit matmul operations using 8-bit intrinsics. However, this approach was quickly flagged by the Meta PyTorch team.

Can you give more context here? Personally, I feel a separate op might make more sense since the packing etc. is pretty different from the existing op.

ng-05 (Collaborator, Author) commented Oct 16, 2024

We initially introduced new PyTorch operators due to the above-mentioned issues and the lack of support for 4-bit matmul operations using 8-bit intrinsics. However, this approach was quickly flagged by the Meta PyTorch team.

Can you give more context here? Personally, I feel a separate op might make more sense since the packing etc. is pretty different from the existing op.

  1. I think their initial argument was that PyTorch ATen ops are more in maintenance mode than in development mode, and the PyTorch team was not in favour of adding two new ops (pack & mm) to ATen.
  2. I think a new op was also required to be supported across all major hardware platforms (Intel, CUDA, etc.), and we initially provided just an aarch64 implementation via KleidiAI.
  3. They wanted the low-bit ops to be part of the torch/ao repo, but KleidiAI is not limited to just low-bit ops. We need KleidiAI in PyTorch to accelerate fp32/fp16/bf16/8-bit/4-bit matmuls in a broader scope.

The existing _convert_weight_to_int4pack and _weight_int4pack_mm op signatures cannot cater to all low-bit use cases:

1. _convert_weight_to_int4pack: does not provide an option to pack bias, scales, and quantized weights into a single packed_weight buffer.
2. _weight_int4pack_mm: it might not be realistic to recover the original weight tensor shape (out_features, in_features) from the packed_weight buffer, because it is packed/padded/reshaped, etc.

jerryzh168 (Contributor) commented:

Can you just add the low-bit-related kernels to torchao? I think even _convert_weight_to_int4pack and _weight_int4pack_mm should live in torchao, but I'm not sure if we have talked about this before or not.

For the fp32/fp16/bf16 kernels that are not related to low bits, do these require the same kind of packing? If so, maybe it's better to just add them to torchao as well.

ng-05 (Collaborator, Author) commented Oct 17, 2024

Can you just add the low-bit-related kernels to torchao? I think even _convert_weight_to_int4pack and _weight_int4pack_mm should live in torchao, but I'm not sure if we have talked about this before or not.

For the fp32/fp16/bf16 kernels that are not related to low bits, do these require the same kind of packing? If so, maybe it's better to just add them to torchao as well.

Thanks for your inputs. We had this discussion with @digantdesai and @malfet, and they asked us to pursue this integration and raise an RFC. Moreover, @digantdesai mentioned that low-bit GEMM integration with Inductor via torch/ao isn't fully clear yet, so we are unsure what our next steps should be.

digantdesai (Contributor) commented Oct 23, 2024

Thanks for the RFC. Here's my high-level thinking, some of which we already discussed asynchronously, so it shouldn't be a surprise. I'm listing some nuances/feedback for the proposed approach in this RFC, as well as for an alternative approach I suggested, which involves housing custom ops in TorchAO.

Regarding the proposed approach in this RFC:

  • ➕ We already have a precedent for this, and it seems to have been working fine for x86, CUDA, and MPS. However, AFAIK, int4 has not been integrated with Inductor beyond the potential external-op level.

  • ➖ I feel we still need to leverage quantization APIs from TorchAO, which means the ops in core seem a bit out of place to me.

  • ➖ Dynamic quantization of activations under a weight-only quantization schema is not a good idea IMHO. We already have a matching schema in TorchAO for this. I'm unsure what the rationale is for overloading weight-only quant ops. A context manager might be an OK solution.

  • I don't have a strong preference on the signature besides that it should be coherent with the existing use of the APIs, i.e., can we pass a bias to the x86 packing fn?

Regarding the TorchAO custom ops approach:

  • ➕ For low-bit dtypes (i.e., 1- to 7-bit int or even float8), you have to go through the quantization flow. TorchAO quantization APIs seem to me the right place to house such custom ops with custom packing/layouts, etc.

  • ➕/➖ The TorchAO custom op path should work for Eager (and ExecuTorch). For Inductor, besides calling it as an external op, it's unclear to me ATM how easy it would be to leverage something custom out-of-tree.

  • ➖ The stability of TorchAO low-bit ops, especially for CPUs, is still experimental and evolving, which is one of my concerns.

  • ➖ This may not be a viable path for native dtypes; i.e., if we have a bf16 @ bf16 = fp32, a user may not need to involve TorchAO at all, thus we must have an implementation in core.

I also discussed this with @malfet, and he has some different opinions; I tried to capture some of them here. Given that he has better visibility into the native op developments, I would rely on his judgement on what the best path forward is here.

kimishpatel (Contributor) commented:

We initially introduced new PyTorch operators due to the above-mentioned issues and the lack of support for 4-bit matmul operations using 8-bit intrinsics. However, this approach was quickly flagged by the Meta PyTorch team.

Can you give more context here? Personally, I feel a separate op might make more sense since the packing etc. is pretty different from the existing op.

Unfortunately, even the existing convert/mm op APIs are not really consistent. Each backend outputs a tensor in a different form for the prepack op. @malfet, this is the thing that you highlighted first, and it has stuck with me ever since. We shouldn't really have backend-specific prepacking hiding behind the same ATen op, since each is doing a different thing.

kimishpatel (Contributor) commented:

Can you just add the low-bit-related kernels to torchao? I think even _convert_weight_to_int4pack and _weight_int4pack_mm should live in torchao, but I'm not sure if we have talked about this before or not.

For the fp32/fp16/bf16 kernels that are not related to low bits, do these require the same kind of packing? If so, maybe it's better to just add them to torchao as well.

Agree with @jerryzh168 here. Also, I think a backend-specific custom op, in this case an ARM-specific custom op, makes more sense. This makes it clear in the API that these ops either output packed weights (prepack op) or accept packed weights (mm op) that are only interpretable by a specific implementation, e.g. ARM's implementation. Thus a CUDA or x86 impl of the mm op cannot accept weights packed for ARM's implementation, and the best way to make this clear is to have separate custom ops. This also helps clarify in the API that activations are quantized dynamically.
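
For illustration only, a minimal sketch of how such a backend-specific op could be registered out of core with the PyTorch 2.4+ torch.library.custom_op API. The "kleidiai" namespace, op name, and reference body are hypothetical; a real build would register the actual KleidiAI kernel (and the matching prepack op) the same way, and only on aarch64.

import torch

@torch.library.custom_op("kleidiai::int4pack_mm", mutates_args=(), device_types="cpu")
def kleidiai_int4pack_mm(x: torch.Tensor, packed: torch.Tensor,
                         group_size: int, n: int, k: int) -> torch.Tensor:
    # Placeholder body: unpack the (implementation-defined) buffer produced by
    # the matching prepack op and fall back to a dequantized FP32 matmul.
    # Only this namespace knows the packed layout, so CUDA/x86 mm ops can
    # never be handed ARM-packed weights by mistake.
    w = packed[: n * k].reshape(n, k)
    scales = packed[n * k : n * k + n]
    bias = packed[-n:]
    return x @ (w * scales.unsqueeze(1)).t() + bias

# Once registered, the op is callable (and traceable) as
# torch.ops.kleidiai.int4pack_mm(...), without touching ATen signatures.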
