
Vidrial



A mixed CUDA/Python framework to write deep learning kernels. With Vidrial you can:

  • Write non-spaghetti CUDA code. By systematically separating static and dynamic code, Vidrial kernels are clean, testable, and make heavy use of reusable components.
  • Call your kernels from Python with PyTorch tensors via a JIT compiler. Sweep over many configurations to find the fastest one for your given problem shape.

We have implemented a bunch of kernels. Whatever you are trying to implement, it's quite likely that you can use one of them as a starting point.

Kernel          Description
mma             matrix multiplication
reduce          generic reduction (e.g. reduce sum, reduce prod...)
sympow          symmetric power of a vector
sympow_bwd      backward pass for sympow
sympow_mma      fused C = sympow(A) @ B
mma_sympow_bwd  backward pass for sympow_mma
flash_attn      flash attention
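
For example, flash_attn implements flash attention, i.e. exact scaled dot-product attention computed tile by tile. A plain PyTorch reference of that math is handy for checking outputs against any attention kernel; the sketch below is generic (the attention_reference helper is not part of Vidrial and ignores details like masking or the kernel's actual signature).

import torch as th

def attention_reference(Q, K, V):
    # Naive scaled dot-product attention, usable as a correctness
    # reference for flash-attention-style kernels (not a Vidrial API).
    scale = Q.shape[-1] ** -0.5
    scores = (Q @ K.transpose(-2, -1)) * scale   # [..., seq_q, seq_k]
    return th.softmax(scores, dim=-1) @ V        # [..., seq_q, head_dim]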

Calling Vidrial Kernels from Python

Vidrial ops can be called like any other PyTorch operation.

import torch as th
from vidrial.ops import mm

b, n = 16, 128
A = th.randn([b, n, n], requires_grad=True, device="cuda")
B = th.randn([b, n, n], requires_grad=True, device="cuda")

# Forward pass: the Vidrial op matches the PyTorch reference
C = mm(A, B)
C_ref = A @ B
assert th.allclose(C, C_ref, atol=5e-2, rtol=5e-2)

# Backward pass through the Vidrial kernel
C.sum().backward()
A_grad, B_grad = A.grad.clone(), B.grad.clone()

# Backward pass through the PyTorch reference
A.grad.zero_()
B.grad.zero_()
C_ref.sum().backward()

# Gradients match as well
assert th.allclose(A_grad, A.grad, atol=5e-2, rtol=5e-2)
assert th.allclose(B_grad, B.grad, atol=5e-2, rtol=5e-2)
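
Because the op supports autograd, it drops straight into a normal model. The sketch below is illustrative only: the BatchedBilinear module is hypothetical, and it simply reuses the mm call from the example above.

import torch as th
import torch.nn as nn
from vidrial.ops import mm

class BatchedBilinear(nn.Module):
    # Hypothetical module: applies a learned per-batch matrix via the mm op.
    def __init__(self, b, n):
        super().__init__()
        self.W = nn.Parameter(th.randn(b, n, n))

    def forward(self, x):
        return mm(x, self.W)  # same batched matmul as above, autograd included

model = BatchedBilinear(16, 128).cuda()
x = th.randn(16, 128, 128, device="cuda")
model(x).sum().backward()  # gradients flow back through the Vidrial kernel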

By default, Vidrial's JIT system runs in PickAny mode: it finds one valid configuration, compiles it, and runs it. Whenever performance is critical, you should run Vidrial kernels in PickBest mode instead, which launches a parallel sweep of configurations to find the best one. Compiled kernels and their timings are cached in vidrial/jit/.jit_cache and vidrial/jit/.timing_cache respectively. To better understand what the JIT system is doing under the hood, turn on logging.

import torch as th
import vidrial.jit as vjit
from vidrial.ops import sympow
import logging
logging.basicConfig(level=logging.INFO)

X = th.randn([1024, 64], device="cuda")

print("Run a sweep over 16 configurations across 8 workers to find the best one")
with vjit.settings.set(vjit.PickBest, max_workers=8, max_configs=16):
    Y = sympow(X, power=2, d_tile=8)

print("Call the fastest kernel from the sweep (already in the JIT cache)")
Y = sympow(X, power=2, d_tile=8)

print("Different shapes trigger new compilations")
X = th.randn([4, 32], device="cuda")
Y = sympow(X, power=2, d_tile=8)

# If your model calls several Vidrial kernels and you don't want the compilations to run sequentially, use precompile mode
X = th.randn([4, 64], device="cuda", requires_grad=True)
print("Compile the fwd and bwd pass kernels in parallel")
with vjit.settings.set(vjit.PickAny, precompile=True):
    Y = sympow(X, power=2, d_tile=8)
    Y.sum().backward()
# WARNING: in precompile mode, Vidrial ops return incorrect values. Rerun the computation once all the kernels are in the JIT cache
Y = sympow(X, power=2, d_tile=8)
Y.sum().backward()
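
To check that a PickBest sweep actually paid off, you can time an op yourself with CUDA events once its kernels are cached. The helper below is generic PyTorch timing code, not part of Vidrial.

import torch as th
from vidrial.ops import sympow

def time_op(fn, iters=100):
    # Generic CUDA-event timer (plain PyTorch, not a Vidrial API)
    fn()  # warm-up; also ensures the kernel is already in the JIT cache
    start = th.cuda.Event(enable_timing=True)
    end = th.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    th.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

X = th.randn([1024, 64], device="cuda")
print(f"sympow: {time_op(lambda: sympow(X, power=2, d_tile=8)):.3f} ms")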

Example Kernel

Vidrial CUDA kernels follow a very specific structure that makes it possible to write clean, reusable and testable code.

  • We avoid any pointer arithmetic by relying heavily on CuTe Layouts.
  • Separation of static and dynamic code. Most Vidrial functions receive a cfg object as their first argument. The config contains only static (compile-time) information: it specifies things like tensor shapes, memory layouts, number of threads, copy and mma instructions...
  • Rigid naming conventions. For example, we use different words to refer to tensors that belong to different parts of the compute hierarchy. GPU tensors are called slabs, CTA tensors are called tiles and thread tensors are called fragments.

This is what a fully featured matrix multiplication kernel looks like. It uses tensor cores, pipelined data movement from global to shared memory using vectorized 128-bit loads, and efficient loading of shared memory data into registers with LDSM instructions.

template <typename Cfg, typename T>
__global__ void tiled_mma_kernel(Cfg cfg, T* A_ptr, T* B_ptr, T* C_ptr) {
    int tid = threadIdx.x;
    int bid_M = blockIdx.x; int bid_N = blockIdx.y; int bid_P = blockIdx.z;
    auto tile_coords = MmaMNKCoords(cfg.MNK_tile_shape);
    tile_coords.step_M(bid_M); tile_coords.step_N(bid_N); tile_coords.step_P(bid_P);
    // ----- Global memory slabs -----
    auto gA_slab = make_tensor(make_gmem_ptr(A_ptr), cfg.A.gSlab);
    auto gB_slab = make_tensor(make_gmem_ptr(B_ptr), cfg.B.gSlab);
    auto gC_slab = make_tensor(make_gmem_ptr(C_ptr), cfg.C.gSlab);
    // ----- Shared memory pipelines -----
    constexpr static int smempipe = static_min(cfg.perf.smempipe, cfg.K_tile_num);
    auto pipe = SmemPipe<smempipe>();
    extern __shared__ char smem[];
    Allocator<16> alloc(smem);
    T* A_smem = alloc.allocate<T>(size(cfg.A.sTile) * smempipe);
    T* B_smem = alloc.allocate<T>(size(cfg.B.sTile) * smempipe);
    auto sA_tile_pipe = pipe.create(A_smem, cfg.A.sTile, cfg.A.tile_copy);
    auto sB_tile_pipe = pipe.create(B_smem, cfg.B.sTile, cfg.B.tile_copy);
    auto rA_frg_mma = cfg.A.make_mma_frg();
    auto rB_frg_mma = cfg.B.make_mma_frg();
    auto rC_frg_mma = cfg.C.make_mma_frg();
    clear(rC_frg_mma);
    // ----- Pipeline fetch A_tile and B_tile -----
    auto pipe_fetch = [&]() {
        if (tile_coords.valid_K_tile(cfg.K)) {
            auto gA_tile = tile_coords.slice_A_tile(gA_slab);
            auto gB_tile = tile_coords.slice_B_tile(gB_slab);
            pipe.fetch(gA_tile, sA_tile_pipe, cfg.A.tile_copy);
            pipe.fetch(gB_tile, sB_tile_pipe, cfg.B.tile_copy);
            tile_coords.step_K();
        }
        pipe.commit();
    };
    // ----- Prefill pipeline -----
    for (; pipe.stage < smempipe - 1; pipe.step())
        pipe_fetch();
    // ----- Main loop -----
    for (int k_tile = 0; k_tile < cfg.K_tile_num; k_tile++) {
        pipe_fetch();
        pipe.ready();
        auto sA_tile = pipe.read(sA_tile_pipe);
        auto sB_tile = pipe.read(sB_tile_pipe);
        load_frg<T, cfg.perf.use_ldsm, false>(sA_tile, cfg.A.mma_FrgThr, rA_frg_mma);
        if constexpr (cfg.perf.regpipe == 0)
            load_frg<T, cfg.perf.use_ldsm, false>(sB_tile, cfg.B.mma_FrgThr, rB_frg_mma);
        vidrial::gemm(cfg, rA_frg_mma, rB_frg_mma, sB_tile, rC_frg_mma);
        pipe.step();
    }
    // ----- Write C_tile to global memory -----
    alloc.reset(smem);
    T* C_smem = alloc.allocate<T>(size(cfg.C.sTile));
    auto sC_tile = make_tensor(make_smem_ptr(C_smem), cfg.C.sTile);
    copy(rC_frg_mma, slice_rest(sC_tile, cfg.C.mma_FrgThr, tid));
    __syncthreads();
    auto gC_tile = tile_coords.slice_C_tile(gC_slab);
    CTA_copy_tile(cfg.C.tile_copy, sC_tile, gC_tile);
}

Installation

Vidrial relies on the nvcc compiler for its JIT system. Make sure you have it installed.

If you want to run the existing Vidrial kernels, you can install the package from PyPI.

pip install vidrial

If you want to develop your own kernels, clone the repository and install it in editable mode.

git clone https://github.com/m-a-n-i-f-e-s-t/vidrial.git
cd vidrial
pip install -e .

To ensure the installation is working, run the tests:

pytest -n auto vidrial/
