Mojo function

dequant_mxfp4

dequant_mxfp4[*, SF_VECTOR_SIZE: Int = 32](ctx: DeviceContext, output: TileTensor[output.dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], input: TileTensor[input.dtype, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], scales: TileTensor[scales.dtype, scales.LayoutType, scales.origin, address_space=scales.address_space, linear_idx_type=scales.linear_idx_type, element_size=scales.element_size], num_rows: Int, num_cols: Int, pdl_level: PDLLevel = PDLLevel())

Dequantizes MXFP4-packed weights to FP8 (`float8_e4m3fn`) or BF16. Each `uint8` input byte packs two 4-bit FP4 (E2M1) values, and each group of `SF_VECTOR_SIZE` output elements shares one `float8_e8m0fnu` scale factor.

Args:

  • ctx (DeviceContext): Device context for kernel launch.
  • output (TileTensor): Output tensor [num_rows, num_cols] of float8_e4m3fn or bfloat16.
  • input (TileTensor): Input tensor [num_rows, num_cols // 2] of uint8 (packed FP4).
  • scales (TileTensor): Scale tensor [num_rows, num_cols // SF_VECTOR_SIZE] of float8_e8m0fnu.
  • num_rows (Int): Number of rows (N dimension for weights).
  • num_cols (Int): Number of columns (K dimension, unpacked).
  • pdl_level (PDLLevel): PDL optimization level for kernel launch.
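The dequantization the kernel performs can be sketched on the host as a NumPy reference: unpack two FP4 (E2M1) codes from each input byte, look up their real values, then multiply each block of `SF_VECTOR_SIZE` columns by its `float8_e8m0fnu` scale (a pure power-of-two, `2**(byte - 127)`). This is a hedged sketch, not the kernel's implementation: the nibble order (low nibble = even column) is an assumption about the packing layout.

```python
import numpy as np

# FP4 E2M1 value table (per the OCP Microscaling spec):
# bits [2:0] select magnitude {0, 0.5, 1, 1.5, 2, 3, 4, 6}, bit 3 is the sign.
FP4_E2M1 = np.array(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],
    dtype=np.float32,
)

def dequant_mxfp4_ref(packed, scales, sf_vector_size=32):
    """Reference MXFP4 dequantization (float32 output for clarity).

    packed: uint8 [num_rows, num_cols // 2], two FP4 codes per byte
            (assumption: low nibble holds the even-indexed column).
    scales: uint8 [num_rows, num_cols // sf_vector_size], raw
            float8_e8m0fnu bytes; value = 2**(byte - 127).
    """
    rows, half = packed.shape
    out = np.empty((rows, half * 2), dtype=np.float32)
    out[:, 0::2] = FP4_E2M1[packed & 0x0F]  # even columns from low nibbles
    out[:, 1::2] = FP4_E2M1[packed >> 4]    # odd columns from high nibbles
    # One power-of-two scale per block of sf_vector_size columns.
    scale_vals = np.exp2(scales.astype(np.float32) - 127.0)
    out *= np.repeat(scale_vals, sf_vector_size, axis=1)
    return out
```

For example, a byte `0x21` with scale byte `128` (= 2.0) yields `0.5 * 2.0` and `1.0 * 2.0` for its two columns. The real kernel would additionally cast the result to `float8_e4m3fn` or `bfloat16` as selected by `output.dtype`.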