Mojo function
mxfp4_matmul_sm90
mxfp4_matmul_sm90(c: TileTensor[c.dtype, c.LayoutType, c.origin, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a: TileTensor[a.dtype, a.LayoutType, a.origin, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b_packed: TileTensor[b_packed.dtype, b_packed.LayoutType, b_packed.origin, address_space=b_packed.address_space, linear_idx_type=b_packed.linear_idx_type, element_size=b_packed.element_size], b_scales: TileTensor[b_scales.dtype, b_scales.LayoutType, b_scales.origin, address_space=b_scales.address_space, linear_idx_type=b_scales.linear_idx_type, element_size=b_scales.element_size], ctx: DeviceContext)
MXFP4 matmul: dequantizes the B weights to FP8, casts A to FP8, then runs an SM90 FP8 GEMM.
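The plain-Python sketch below is a hypothetical reference model of the result this function produces. It is not the kernel's implementation. It assumes the OCP MX conventions. Each byte of b_packed holds two E2M1 (FP4) codes, and it assumes the low nibble is the even-indexed element, which is a guess about this kernel's packing order. Each float8_e8m0fnu scale is a power of two, 2^(e - 127), shared by a block of 32 consecutive K elements. From the documented shapes, the sketch computes C = A · Bᵀ in full precision. It skips the intermediate FP8 casts, so the kernel's bfloat16 output should differ by rounding. All helper names are illustrative.

```python
# Hypothetical reference model (not the Mojo kernel): dequantize MXFP4 B, then C = A @ B^T.

# E2M1 magnitudes for 3-bit codes 0..7; bit 3 of the nibble is the sign.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(nibble: int) -> float:
    """Decode a 4-bit E2M1 code into a float."""
    magnitude = E2M1_MAGNITUDES[nibble & 0x7]
    return -magnitude if nibble & 0x8 else magnitude


def decode_e8m0(byte: int) -> float:
    """Decode an E8M0 scale: an unsigned power of two with bias 127; 0xFF encodes NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def dequant_mxfp4(b_packed, b_scales, k):
    """Expand packed [N][K//2] bytes and [N][K//32] scales into a dense [N][K] list."""
    rows = []
    for packed_row, scale_row in zip(b_packed, b_scales):
        row = []
        for j in range(k):
            byte = packed_row[j // 2]
            # Assumption: the low nibble holds the even-indexed element.
            nibble = byte & 0x0F if j % 2 == 0 else byte >> 4
            # One scale per 32-element block along K.
            row.append(decode_e2m1(nibble) * decode_e8m0(scale_row[j // 32]))
        rows.append(row)
    return rows


def mxfp4_matmul_reference(a, b_packed, b_scales):
    """c[m][n] = sum over k of a[m][k] * B[n][k], with B dequantized from MXFP4."""
    k = len(a[0])
    b = dequant_mxfp4(b_packed, b_scales, k)
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b] for a_row in a]
```

Example: take a 1×32 row of ones and one weight row whose 16 bytes are all 0x21. Each byte decodes to the codes 0.5 and 1.0. With scale byte 128 (2.0), the result is 16 × (1.0 + 2.0) = 48.0.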
Args:
- c (TileTensor): Output [M, N] in bfloat16.
- a (TileTensor): Activations [M, K] in bfloat16.
- b_packed (TileTensor): Weights [N, K//2] in uint8 (packed MXFP4).
- b_scales (TileTensor): Weight scales [N, K//32] in float8_e8m0fnu.
- ctx (DeviceContext): Device context.
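The argument shapes above all derive from the same M, N, and K. A host-side consistency check could look like the Python sketch below. The function name is hypothetical. It assumes K is a multiple of 32, as the K//32 scale shape implies. The kernel may impose further tile-alignment requirements that this page does not document.

```python
# Hypothetical host-side check of the documented argument shapes (not part of the Mojo API).
def check_mxfp4_shapes(c_shape, a_shape, b_packed_shape, b_scales_shape):
    """Return (M, N, K) if the shapes match the documented layouts, else raise ValueError."""
    m, k = a_shape                      # a: [M, K]
    n, k_half = b_packed_shape          # b_packed: [N, K//2], two FP4 values per byte
    if k % 32 != 0:
        raise ValueError(f"K={k} must be a multiple of the 32-element scale block")
    if k_half != k // 2:
        raise ValueError(f"b_packed has {k_half} bytes per row, expected K//2={k // 2}")
    if tuple(b_scales_shape) != (n, k // 32):  # b_scales: [N, K//32]
        raise ValueError(f"b_scales shape {b_scales_shape}, expected {(n, k // 32)}")
    if tuple(c_shape) != (m, n):        # c: [M, N]
        raise ValueError(f"c shape {c_shape}, expected {(m, n)}")
    return m, n, k
```

For example, check_mxfp4_shapes((4, 8), (4, 64), (8, 32), (8, 2)) returns (4, 8, 64).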