Mojo function
dequant_mxfp4
dequant_mxfp4[*, SF_VECTOR_SIZE: Int = 32](ctx: DeviceContext, output: TileTensor[address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], input: TileTensor[address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], scales: TileTensor[address_space=scales.address_space, linear_idx_type=scales.linear_idx_type, element_size=scales.element_size], num_rows: Int, num_cols: Int, pdl_level: PDLLevel = PDLLevel())
Dequantize MXFP4 packed weights to FP8 or BF16.
Args:
- ctx (DeviceContext): Device context for kernel launch.
- output (TileTensor[address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor [num_rows, num_cols] of float8_e4m3fn or bfloat16.
- input (TileTensor[address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size]): Input tensor [num_rows, num_cols // 2] of uint8 (packed FP4).
- scales (TileTensor[address_space=scales.address_space, linear_idx_type=scales.linear_idx_type, element_size=scales.element_size]): Scale tensor [num_rows, num_cols // SF_VECTOR_SIZE] of float8_e8m0fnu.
- num_rows (Int): Number of rows (N dimension for weights).
- num_cols (Int): Number of columns (K dimension, unpacked).
- pdl_level (PDLLevel): PDL optimization level for kernel launch.
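The numeric effect of the kernel can be sketched as a CPU reference in Python. The FP4 (e2m1) value table and the e8m0 power-of-two scale follow the OCP Microscaling format; the nibble packing order (low nibble holds the even column) and the helper name `dequant_mxfp4_ref` are assumptions for illustration, not part of this API.

```python
# FP4 e2m1 lookup: sign bit, 2 exponent bits, 1 mantissa bit -> 16 values.
FP4_E2M1 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
            -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def dequant_mxfp4_ref(packed, scales_e8m0, num_cols, sf_vector_size=32):
    """CPU reference sketch of MXFP4 dequantization.

    packed: per-row lists of uint8, each byte holding two FP4 values
            (shape [num_rows][num_cols // 2]).
    scales_e8m0: per-row lists of e8m0 bytes, one scale per
            sf_vector_size unpacked elements.
    Returns the dequantized values as Python floats.
    """
    out = []
    for row_bytes, row_scales in zip(packed, scales_e8m0):
        row = []
        for j in range(num_cols):
            byte = row_bytes[j // 2]
            # Assumed packing: even column in the low nibble.
            nib = byte & 0xF if j % 2 == 0 else byte >> 4
            # e8m0 stores only a biased exponent: value = 2^(bits - 127).
            scale = 2.0 ** (row_scales[j // sf_vector_size] - 127)
            row.append(FP4_E2M1[nib] * scale)
        out.append(row)
    return out
```

For example, the packed byte 0x32 decodes to 1.0 (low nibble 0x2) and 1.5 (high nibble 0x3); with an e8m0 scale byte of 128 (i.e. 2^1) those become 2.0 and 3.0.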