Opening this issue to track adding Block Scaled Dot Products for MX types.
References:
We can basically emit a `__op$block_scaled_dot` custom call to trigger the proper lowerings. Note that it properly lowers to the right cuDNN call on supported hardware.
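For illustration, here is a minimal sketch of what the emitted IR could look like. The operand order, shapes, scale layout, and result type below are assumptions for the sake of the example, not the actual contract of the custom call:

```mlir
// Hypothetical emission sketch: operand order, shapes, and result type
// are assumptions; only the @__op$block_scaled_dot target name is given.
func.func @scaled_dot(%lhs : tensor<128x128xf8E5M2>,
                      %lhs_scale : tensor<128x4xf8E8M0FNU>,
                      %rhs : tensor<128x128xf8E5M2>,
                      %rhs_scale : tensor<128x4xf8E8M0FNU>) -> tensor<128x128xf16> {
  // The custom call is what the backend pass pattern-matches and lowers
  // to the corresponding cuDNN call on supported hardware.
  %result = stablehlo.custom_call @__op$block_scaled_dot(%lhs, %rhs, %lhs_scale, %rhs_scale)
      : (tensor<128x128xf8E5M2>, tensor<128x128xf8E5M2>,
         tensor<128x4xf8E8M0FNU>, tensor<128x4xf8E8M0FNU>) -> tensor<128x128xf16>
  return %result : tensor<128x128xf16>
}
```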
Should the underlying compute not support the data types (i.e. pre-Blackwell for NVIDIA), or the plugin not support the pass, we can always take the verbose route according to the OpenXLA discussion:
```mlir
func.func @dequantize(%x : tensor<128xf8E5M2>, %x_scale : tensor<4xf8E8M0FNU>) -> tensor<128xf16> {
  // Step 1: Convert both tensors to FP16.
  %x_f16 = stablehlo.convert %x : (tensor<128xf8E5M2>) -> tensor<128xf16>
  %x_scale_f16 = stablehlo.convert %x_scale : (tensor<4xf8E8M0FNU>) -> tensor<4xf16>
  // Step 2: Broadcast and reshape scale tensor.
  %x_scale_f16_broadcast = stablehlo.broadcast_in_dim %x_scale_f16, dims = [0] : (tensor<4xf16>) -> tensor<4x32xf16>
  %x_scale_f16_reshape = stablehlo.reshape %x_scale_f16_broadcast : (tensor<4x32xf16>) -> tensor<128xf16>
  // Step 3: Multiply the tensors.
  %result = stablehlo.multiply %x_f16, %x_scale_f16_reshape : tensor<128xf16>
  return %result : tensor<128xf16>
}
```
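From there, the dequantized operands can feed an ordinary dot. A minimal sketch reusing the shapes from `@dequantize` above (the function name and the assumption that both operands were already dequantized are mine):

```mlir
// Hypothetical fallback: both inputs are assumed to already be
// dequantized to f16 via @dequantize above, so a plain dot_general
// suffices and no scaled-dot hardware support is needed.
func.func @fallback_dot(%lhs_f16 : tensor<128xf16>, %rhs_f16 : tensor<128xf16>) -> tensor<f16> {
  %result = stablehlo.dot_general %lhs_f16, %rhs_f16,
      contracting_dims = [0] x [0] : (tensor<128xf16>, tensor<128xf16>) -> tensor<f16>
  return %result : tensor<f16>
}
```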