
Block Scaled Dot Product #294

@steeve

Description

Opening this issue to track adding Block Scaled Dot Products for MX types.

References:

We can basically emit a __op$block_scaled_dot custom call to trigger the proper lowerings. Note that it properly lowers to the right cuDNN call on supported hardware.
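
For reference, a minimal sketch of what emitting that custom call could look like. The operand order, shapes, scale layout, and E5M2/E8M0 pairing here are assumptions for illustration, not a finalized ABI:

func.func @block_scaled_dot(%lhs : tensor<4x128xf8E5M2>, %lhs_scale : tensor<4x4xf8E8M0FNU>,
                            %rhs : tensor<128x4xf8E5M2>, %rhs_scale : tensor<4x4xf8E8M0FNU>) -> tensor<4x4xf16> {
  // The backend recognizes the call target and lowers it to the proper kernel
  // (e.g. a cuDNN block scaled matmul on supported hardware).
  %result = stablehlo.custom_call @__op$block_scaled_dot(%lhs, %lhs_scale, %rhs, %rhs_scale)
      : (tensor<4x128xf8E5M2>, tensor<4x4xf8E8M0FNU>, tensor<128x4xf8E5M2>, tensor<4x4xf8E8M0FNU>) -> tensor<4x4xf16>
  return %result : tensor<4x4xf16>
}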

Should the underlying compute not support the data types (i.e. pre-Blackwell for NVIDIA), or should the plugin not support the pass, we can always take the verbose route described in the OpenXLA discussion. A dot on dequantized operands is sketched after this example:

func.func @dequantize(%x : tensor<128xf8E5M2>, %x_scale : tensor<4xf8E8M0FNU>) -> tensor<128xf16> {
  // Step 1: Convert both tensors to FP16.
  %x_f16 = stablehlo.convert %x : (tensor<128xf8E5M2>) -> tensor<128xf16>
  %x_scale_f16 = stablehlo.convert %x_scale : (tensor<4xf8E8M0FNU>) -> tensor<4xf16>

  // Step 2: Broadcast and reshape scale tensor.
  %x_scale_f16_broadcast = stablehlo.broadcast_in_dim %x_scale_f16, dims = [0] : (tensor<4xf16>) -> tensor<4x32xf16>
  %x_scale_f16_reshape = stablehlo.reshape %x_scale_f16_broadcast : (tensor<4x32xf16>) -> tensor<128xf16>

  // Step 3: Multiply the tensors.
  %result = stablehlo.multiply %x_f16, %x_scale_f16_reshape : tensor<128xf16>
  return %result : tensor<128xf16>
}
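
Continuing the verbose route, the dequantized operands can then feed a regular dot. A minimal sketch on 1-D operands, assuming the dequantize function above; the @block_scaled_dot_fallback name and the scalar result shape are illustrative, not part of any spec:

func.func @block_scaled_dot_fallback(%lhs : tensor<128xf8E5M2>, %lhs_scale : tensor<4xf8E8M0FNU>,
                                     %rhs : tensor<128xf8E5M2>, %rhs_scale : tensor<4xf8E8M0FNU>) -> tensor<f16> {
  // Dequantize both operands with the pattern above.
  %lhs_f16 = func.call @dequantize(%lhs, %lhs_scale) : (tensor<128xf8E5M2>, tensor<4xf8E8M0FNU>) -> tensor<128xf16>
  %rhs_f16 = func.call @dequantize(%rhs, %rhs_scale) : (tensor<128xf8E5M2>, tensor<4xf8E8M0FNU>) -> tensor<128xf16>

  // Contract the single dimension of both vectors, yielding a scalar.
  %result = stablehlo.dot_general %lhs_f16, %rhs_f16, contracting_dims = [0] x [0] : (tensor<128xf16>, tensor<128xf16>) -> tensor<f16>
  return %result : tensor<f16>
}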
