[ARM][feat]: Add KleidiAI Backend & enable 4 bit matmul operators #134124

Open
wants to merge 1 commit into base: main
3 changes: 3 additions & 0 deletions .gitmodules
@@ -127,3 +127,6 @@
[submodule "third_party/NVTX"]
path = third_party/NVTX
url = https://github.com/NVIDIA/NVTX.git
[submodule "third_party/kleidiai"]
path = third_party/kleidiai
url = https://git.gitlab.arm.com/kleidi/kleidiai.git
2 changes: 2 additions & 0 deletions CMakeLists.txt
@@ -391,6 +391,8 @@ cmake_dependent_option(
cmake_dependent_option(BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
cmake_dependent_option(BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler"
OFF "USE_CUDA" OFF)
cmake_dependent_option(USE_KLEIDIAI "Use KleidiAI for the ARM CPU & AARCH64 architecture." ON
"CPU_AARCH64" OFF)
@snadampal (Collaborator) commented on Aug 26, 2024:

Is KleidiAI supported on all AArch64 device variants, i.e. Cortex-A57, A72, Neoverse N1/N2/V1/V2, etc.? It doesn't look like it, so it should be guarded with additional flags such as the I8MM feature or runtime cpuinfo checks.

@ng-05 (Collaborator, Author) replied on Aug 27, 2024:

KleidiAI is guarded internally, at both compile time and run time, by the I8MM and DOTPROD feature flags. Since the issue can be handled inside the library, is there any need to expose it to PyTorch?
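
For illustration only (not part of this diff or the review thread): if a caller-side runtime guard were ever added on top of KleidiAI's internal checks, a minimal sketch using the cpuinfo library that PyTorch already bundles could look like the following, assuming cpuinfo exposes the AArch64 feature predicates cpuinfo_has_arm_i8mm() and cpuinfo_has_arm_neon_dot():

#include <cpuinfo.h>

// Hypothetical helper; the name can_use_kleidiai() does not appear in this PR.
static bool can_use_kleidiai() {
  // cpuinfo_initialize() returns false if CPU detection fails.
  if (!cpuinfo_initialize()) {
    return false;
  }
  // Require both the int8 matmul (I8MM) and dot product (DOTPROD) extensions,
  // mirroring the guards the reply above says KleidiAI applies internally.
  return cpuinfo_has_arm_i8mm() && cpuinfo_has_arm_neon_dot();
}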


option(USE_MIMALLOC "Use mimalloc" OFF)
# Enable third party mimalloc library to improve memory allocation performance
6 changes: 6 additions & 0 deletions WORKSPACE
@@ -180,6 +180,12 @@ new_local_repository(
path = "third_party/tensorpipe",
)

new_local_repository(
name = "kleidiai",
build_file = "//third_party:kleidiai/BUILD.bazel",
path = "third_party/kleidiai",
)

http_archive(
name = "mkl",
build_file = "//third_party:mkl.BUILD",
10 changes: 9 additions & 1 deletion aten/src/ATen/CMakeLists.txt
@@ -197,6 +197,10 @@ endif()
# XNNPACK
file(GLOB native_xnnpack "native/xnnpack/*.cpp")

# KLEIDIAI
file(GLOB native_kleidiai "native/kleidiai/*.cpp")
file(GLOB native_kleidiai_h "native/kleidiai/*.h")

# Add files needed from jit folders
append_filelist("jit_core_headers" ATen_CORE_HEADERS)
append_filelist("jit_core_sources" ATen_CORE_SRCS)
@@ -226,6 +230,10 @@ endif()
if(AT_MKL_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkl_cpp})
endif()
if(AT_KLEIDIAI_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${native_kleidiai})
include_directories(SYSTEM INTERFACE ${KLEIDIAI_INCLUDE_DIRS})
endif()
if(AT_MKLDNN_ENABLED)
set(all_cpu_cpp ${all_cpu_cpp} ${mkldnn_cpp})
endif()
@@ -570,7 +578,7 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/cmake-exports/ATenConfig.cmake"

set(INSTALL_HEADERS ${base_h} ${ATen_CORE_HEADERS} ${native_nested_h} ${ATen_TRANSFORMER_HEADERS})
if(NOT INTERN_BUILD_MOBILE)
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h})
list(APPEND INSTALL_HEADERS ${native_h} ${native_cpu_h} ${native_ao_sparse_h} ${native_quantized_h} ${cuda_h} ${native_kleidiai_h} ${native_cuda_h} ${native_hip_h} ${cudnn_h} ${hip_h} ${xpu_h} ${mps_h} ${native_mps_h} ${native_utils_h} ${miopen_h})
# Metal
if(USE_PYTORCH_METAL_EXPORT)
# Add files needed from exporting metal models(optimized_for_mobile)
1 change: 1 addition & 0 deletions aten/src/ATen/Config.h.in
@@ -19,3 +19,4 @@
#define AT_PARALLEL_NATIVE @AT_PARALLEL_NATIVE@
#define AT_BLAS_F2C() @AT_BLAS_F2C@
#define AT_BLAS_USE_CBLAS_DOT() @AT_BLAS_USE_CBLAS_DOT@
#define AT_KLEIDIAI_ENABLED() @AT_KLEIDIAI_ENABLED@
4 changes: 4 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -352,6 +352,10 @@ bool Context::hasMKLDNN() {
#endif
}

bool Context::hasKleidiAI() {
return AT_KLEIDIAI_ENABLED();
}

bool Context::hasOpenMP() {
#ifdef _OPENMP
return true;
5 changes: 5 additions & 0 deletions aten/src/ATen/Context.h
@@ -115,6 +115,7 @@ class TORCH_API Context {
}
static bool hasOpenMP();
static bool hasMKL();
static bool hasKleidiAI();
static bool hasLAPACK();
static bool hasMKLDNN();
static bool hasMAGMA() {
@@ -525,6 +526,10 @@ inline bool hasMKL() {
return globalContext().hasMKL();
}

inline bool hasKleidiAI() {
return globalContext().hasKleidiAI();
}

inline bool hasLAPACK() {
return globalContext().hasLAPACK();
}
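Not part of the diff, but for context: with the Context.h addition above, ATen code can query KleidiAI availability the same way it already queries MKL or LAPACK. A minimal usage sketch (the surrounding function is hypothetical):

#include <ATen/Context.h>
#include <iostream>

// Prints which optional CPU backends this build of ATen was compiled with.
void print_cpu_backends() {
  std::cout << "KleidiAI: " << (at::hasKleidiAI() ? "yes" : "no") << '\n';
  std::cout << "MKL:      " << (at::hasMKL() ? "yes" : "no") << '\n';
}

Note that at::hasKleidiAI() reports compile-time availability only (it returns AT_KLEIDIAI_ENABLED()), which is part of what the review comment above is probing.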
150 changes: 107 additions & 43 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -25,6 +25,10 @@
#include <c10/util/irange.h>
#include <variant>

#if AT_KLEIDIAI_ENABLED()
#include <ATen/native/kleidiai/kai_kernels.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
@@ -3429,78 +3433,138 @@ DEFINE_DISPATCH(int8pack_mm_stub);

Tensor _convert_weight_to_int4pack_cpu(
const Tensor& in,
int64_t innerKTiles) {
int64_t innerKTiles,
const int64_t qGroupSize,
int64_t N,
const Tensor& qScaleAndZeros,
const Tensor& bias) {
TORCH_CHECK(in.dim() == 2, __func__, " : expect weight to be 2D tensor.");
TORCH_CHECK(
in.dtype() == at::kByte, __func__, " : expect weight to be kByte.");
auto weight = in.contiguous();

TORCH_CHECK(in.dim() == 2,
__func__, " : expect weight to be 2D tensor.");
TORCH_CHECK(in.dtype() == at::kByte,
__func__, " : expect weight to be kByte.");
TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
__func__, " : innerKTiles need to be 2, 4, or 8, got ", innerKTiles);
int64_t K = weight.size(1) * 2;
std::vector<int64_t> weight_packed_dims;
at::TensorOptions weight_packed_options = at::TensorOptions().dtype(at::kInt);

auto weight = in.contiguous();
auto N = weight.size(0);
auto K = weight.size(1) * 2;
#if AT_KLEIDIAI_ENABLED()
TORCH_CHECK(
qGroupSize == 0 || qGroupSize == 32,
__func__,
": Group size should be 32 or 0. Provided ",
qGroupSize);

if (qGroupSize == 0) {
TORCH_CHECK(
qScaleAndZeros.numel() != 0,
__func__,
": Scales can't be null for channelwise weight packing");
} else {
TORCH_CHECK(
bias.numel() == 0,
__func__,
": Bias not supported in groupwise kernel");
K = (weight.size(0) / N) * qGroupSize;
}

const int64_t rhs_packed_size =
kleidiai::kai_pack_rhs_int4_size(N, K, qGroupSize);
weight_packed_dims = {rhs_packed_size};
weight_packed_options = weight_packed_options.dtype(at::kByte);
#else
TORCH_CHECK(
innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8,
__func__,
" : innerKTiles need to be 2, 4, or 8, got ",
innerKTiles);
N = weight.size(0);

// Create fake shapes for cpu. The meta registration in dynamo requires
// operator has the same output shape for each device. So creating a fake
// shape {N / 8, K / (16 * innerKTiles), 32, innerKTiles / 2}
constexpr int64_t kNTileSize = 8;
constexpr int64_t kKTileSize = 16;
const int64_t kSuperKTileSize = kKTileSize * innerKTiles;
auto nTiles = (N + kNTileSize - 1) / kNTileSize;
TORCH_CHECK(N % 16 == 0, __func__, " : expect N to be divisible by 16");
TORCH_CHECK(
K % kSuperKTileSize == 0,
__func__,
" : expect K to be divisible by ",
kSuperKTileSize);

TORCH_CHECK(N % 16 == 0,
__func__, " : expect N to be dividable by 16");
const int64_t kSuperKTileSize = kKTileSize * innerKTiles;
TORCH_CHECK( K % kSuperKTileSize == 0,
__func__, " : epxect K to be dividable by ", kSuperKTileSize);
auto kSuperTiles = (K + kSuperKTileSize - 1) / kSuperKTileSize;
weight_packed_dims = {nTiles, kSuperTiles, 32, innerKTiles / 2};
#endif

auto weight_packed = at::empty(
{nTiles, kSuperTiles, 32, innerKTiles / 2},
at::TensorOptions().dtype(at::kInt));
auto weight_packed = at::empty(weight_packed_dims, weight_packed_options);

#if AT_KLEIDIAI_ENABLED()
kleidiai::kai_pack_int4_rhs(
weight_packed, weight, qScaleAndZeros, bias, N, K, qGroupSize);
#else
weight_to_int4pack_stub(kCPU, weight_packed, weight, N, K);
#endif

return weight_packed;
}

Tensor _weight_int4pack_mm_cpu(
const Tensor& A,
const Tensor& B,
int64_t qGroupSize,
const Tensor& qScaleAndZeros) {

constexpr int64_t kNTileSize = 8;

const int64_t qGroupSize,
const Tensor& qScaleAndZeros,
int64_t N) {
auto M = A.size(0);
auto N = B.size(0) * kNTileSize;
auto K = A.size(1);

TORCH_CHECK(A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
__func__, " : expect A to be either 32-bit or 16-bit float tensor.");
TORCH_CHECK(A.is_contiguous(),
__func__, " : expect A to be contiguous.");
TORCH_CHECK(A.dim() == 2,
__func__, " : expect A to be 2D tensor.");
#if AT_KLEIDIAI_ENABLED()
TORCH_CHECK(
A.dtype() == kFloat,
__func__,
" : expect A to be either 32-bit float tensor.");
TORCH_CHECK(
qGroupSize == 0 || qGroupSize == 32,
__func__,
": Group size should be 32 or 0");
#else
constexpr int64_t kNTileSize = 8;
N = B.size(0) * kNTileSize;

TORCH_CHECK(B.dtype() == kInt,
__func__, " : expect B to be int32 tensor.");
TORCH_CHECK(B.is_contiguous(),
__func__, " : expect B to be contiguous.");
TORCH_CHECK(B.dim() == 4,
__func__, " : expect B to 4d tensor.");
TORCH_CHECK(
A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
__func__,
" : expect A to be either 32-bit or 16-bit float tensor.");
TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");

TORCH_CHECK(qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128
|| qGroupSize == 256,
__func__, ": expect qGroupSize to be 32, 64, 128 or 256, got ", qGroupSize);
TORCH_CHECK(B.dtype() == kInt, __func__, " : expect B to be int32 tensor.");
TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
TORCH_CHECK(B.dim() == 4, __func__, " : expect B to 4d tensor.");

TORCH_CHECK(qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N
&& qScaleAndZeros.size(2) == 2,
__func__, ": expect qScaleAndZeros to be 3d tensor with sizes [:, ", N, ", 2]");
TORCH_CHECK(
qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
qGroupSize == 256,
__func__,
": expect qGroupSize to be 32, 64, 128 or 256, got ",
qGroupSize);

TORCH_CHECK(
qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(1) == N &&
qScaleAndZeros.size(2) == 2,
__func__,
": expect qScaleAndZeros to be 3d tensor with sizes [:, ",
N,
", 2]");
#endif

auto C = at::empty({M, N}, A.options());
int4pack_mm_stub(kCPU, C, A, B, qGroupSize, qScaleAndZeros, N, K);

#if AT_KLEIDIAI_ENABLED()
kleidiai::kai_quant_pack_lhs_int4_mm(C, A, B, M, N, K, qGroupSize);
#else
int4pack_mm_stub(kCPU, C, A, B, qGroupSize, qScaleAndZeros, N, K);
#endif
return C;
}

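An illustrative sketch (not part of the diff) of how the reworked CPU entry points compose on the KleidiAI path, using only the conventions visible above: fp32 activations, a qGroupSize of 32 for the groupwise kernel, and an empty bias, which that path requires. In a real build these ops are normally reached through the dispatcher rather than called directly, and the wrapper below is hypothetical:

#include <ATen/ATen.h>

// Hypothetical wrapper mirroring the signatures shown in the diff above.
at::Tensor int4_linear_example(
    const at::Tensor& A_f32,           // [M, K] float32 activations
    const at::Tensor& weight_u8,       // 2D kByte tensor of packed int4 weights
    const at::Tensor& scales_and_zeros,
    int64_t N) {
  const int64_t qGroupSize = 32;       // groupwise path; 0 selects channelwise
  auto empty_bias = at::zeros({0}, A_f32.options());  // groupwise packing rejects a bias
  // Pack the int4 weights into KleidiAI's RHS layout
  // (innerKTiles is only validated on the non-KleidiAI path).
  auto packed = at::native::_convert_weight_to_int4pack_cpu(
      weight_u8, /*innerKTiles=*/2, qGroupSize, N, scales_and_zeros, empty_bias);
  // Quantize and pack the LHS, then run the int4 matmul (kai_quant_pack_lhs_int4_mm).
  return at::native::_weight_int4pack_mm_cpu(
      A_f32, packed, qGroupSize, scales_and_zeros, N);
}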
13 changes: 8 additions & 5 deletions aten/src/ATen/native/cuda/int4mm.cu
@@ -1093,12 +1093,12 @@ __global__ void matrix_to_m16n8k16_Bint4_layout(

#endif


at::Tensor _weight_int4pack_mm_cuda(
const at::Tensor& A,
const at::Tensor& B,
int64_t qGroupSize,
const at::Tensor& qScaleAndZeros) {
const int64_t qGroupSize,
const at::Tensor& qScaleAndZeros,
int64_t N) {
c10::cuda::CUDAGuard g(A.device());

TORCH_CHECK(
@@ -1285,7 +1285,11 @@ at::Tensor _weight_int4pack_mm_cuda(
// output is [n / 8][k / (InnerKTiles * 16)][32][innerKTiles / 2] (int32 dtype)
at::Tensor _convert_weight_to_int4pack_cuda(
const at::Tensor& in,
int64_t innerKTiles) {
int64_t innerKTiles,
const int64_t qGroupSize,
int64_t N,
const Tensor& qScaleAndZeros,
const Tensor& bias) {
c10::cuda::CUDAGuard g(in.device());

TORCH_CHECK(in.dim() == 2);
@@ -1360,5 +1364,4 @@
return out;
}


} // namespace at::native