Mojo function

mxfp4_matmul_sm90

mxfp4_matmul_sm90(c: TileTensor[c.dtype, c.LayoutType, c.origin, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a: TileTensor[a.dtype, a.LayoutType, a.origin, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b_packed: TileTensor[b_packed.dtype, b_packed.LayoutType, b_packed.origin, address_space=b_packed.address_space, linear_idx_type=b_packed.linear_idx_type, element_size=b_packed.element_size], b_scales: TileTensor[b_scales.dtype, b_scales.LayoutType, b_scales.origin, address_space=b_scales.address_space, linear_idx_type=b_scales.linear_idx_type, element_size=b_scales.element_size], ctx: DeviceContext)

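Performs an MXFP4 matrix multiplication: dequantizes the packed MXFP4 weights in B to FP8, casts the activations in A to FP8, then runs an FP8 GEMM on SM90 (NVIDIA Hopper).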

Args:

  • c (TileTensor): Output [M, N] in bfloat16.
  • a (TileTensor): Activations [M, K] in bfloat16.
  • b_packed (TileTensor): Weights [N, K//2] in uint8 (packed MXFP4).
  • b_scales (TileTensor): Weight scales [N, K//32] in float8_e8m0fnu.
  • ctx (DeviceContext): Device context.
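The argument shapes above imply that each byte of `b_packed` holds two 4-bit weights, that every 32 consecutive weights along K share one `float8_e8m0fnu` scale, and that the result is `c[m][n] = sum_k a[m][k] * w[n][k]`. The following pure-Python sketch illustrates that data layout on the host. It is not the kernel's implementation: the helper names (`decode_fp4`, `decode_e8m0`, `dequant_row`, `reference_matmul`) are hypothetical, the nibble order (low nibble first) is an assumption this page does not state, and the sketch skips the FP8 rounding of A and of the dequantized B as well as the bfloat16 rounding of the output.

```python
# E2M1 magnitudes indexed by the low 3 bits of a 4-bit code; bit 3 is the sign.
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK = 32  # K elements sharing one scale, matching the [N, K//32] scale shape


def decode_fp4(code):
    """Decode one 4-bit E2M1 code to a Python float."""
    mag = FP4_E2M1[code & 0x7]
    return -mag if code & 0x8 else mag


def decode_e8m0(byte):
    """Decode an exponent-only E8M0 scale: 2**(byte - 127); 0xFF encodes NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def dequant_row(packed_row, scale_row):
    """Expand one weight row: K//2 packed bytes + K//32 scale bytes -> K floats."""
    out = []
    for i, byte in enumerate(packed_row):
        # Elements 2*i and 2*i + 1 always fall in the same 32-element block.
        scale = decode_e8m0(scale_row[(2 * i) // BLOCK])
        # ASSUMPTION: the low nibble holds the even-indexed element.
        out.append(decode_fp4(byte & 0x0F) * scale)
        out.append(decode_fp4(byte >> 4) * scale)
    return out


def reference_matmul(a, b_packed, b_scales):
    """a: [M][K] floats, b_packed: [N][K//2] ints, b_scales: [N][K//32] ints."""
    w = [dequant_row(p, s) for p, s in zip(b_packed, b_scales)]
    # B is stored as [N, K], so each output element is a row-by-row dot product.
    return [[sum(x * y for x, y in zip(a_row, w_row)) for w_row in w]
            for a_row in a]
```

For example, with M = 1, N = 1, K = 32, `reference_matmul([[1.0] * 32], [[0x21] * 16], [[128]])` dequantizes each byte to the pair (0.5, 1.0) times a scale of 2.0. A host-side reference like this can be useful for checking kernel output within a tolerance that accounts for the FP8 and bfloat16 rounding it omits.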
