Mojo function
dequant_mxfp4
dequant_mxfp4[*, SF_VECTOR_SIZE: Int = 32](ctx: DeviceContext, output: TileTensor[output.dtype, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], input: TileTensor[input.dtype, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], scales: TileTensor[scales.dtype, scales.LayoutType, scales.origin, address_space=scales.address_space, linear_idx_type=scales.linear_idx_type, element_size=scales.element_size], num_rows: Int, num_cols: Int, pdl_level: PDLLevel = PDLLevel())
Dequantizes MXFP4-packed weights to FP8 (float8_e4m3fn) or BF16 (bfloat16).
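Per element, the dequantization can be written as follows. This assumes the standard OCP Microscaling (MX) encodings: E2M1 for the 4-bit values (exponent bias 1) and E8M0 for the shared scale (bias 127, with byte 255 reserved for NaN). The symbols are introduced here for illustration: q_{r,c} is the 4-bit code for unpacked column c of row r, with sign bit s, 2-bit exponent E and 1-bit mantissa m, and e_{r,b} is the scale byte for block b of row r. The final conversion to the output dtype is not shown.

```latex
x_{r,c} = 2^{\,e_{r,\lfloor c / V \rfloor} - 127} \cdot \mathrm{E2M1}(q_{r,c}),
\qquad V = \texttt{SF\_VECTOR\_SIZE}

\mathrm{E2M1}(q) = (-1)^{s} \cdot
\begin{cases}
  m / 2 & E = 0 \\
  2^{E-1}\,(1 + m/2) & E > 0
\end{cases}
```

With these definitions the representable E2M1 magnitudes are 0, 0.5, 1, 1.5, 2, 3, 4 and 6.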
Args:
- ctx (DeviceContext): Device context for kernel launch.
- output (TileTensor): Output tensor [num_rows, num_cols] of float8_e4m3fn or bfloat16.
- input (TileTensor): Input tensor [num_rows, num_cols // 2] of uint8 (packed FP4).
- scales (TileTensor): Scale tensor [num_rows, num_cols // SF_VECTOR_SIZE] of float8_e8m0fnu.
- num_rows (Int): Number of rows (N dimension for weights).
- num_cols (Int): Number of columns (K dimension, unpacked).
- pdl_level (PDLLevel): PDL optimization level for kernel launch.
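To make the tensor shapes above concrete, here is a minimal host-side Python sketch of the same arithmetic. It is not the Mojo kernel and does not use its API. It assumes OCP MX encodings (E2M1 values, E8M0 scales with 0xFF as NaN) and a low-nibble-first packing order, where the even column sits in bits 0-3 of each byte. That packing order is an assumption that this page does not confirm. The helpers `decode_e2m1`, `decode_e8m0` and `dequant_mxfp4_reference` are hypothetical names. The sketch returns Python floats and does not model the kernel's final cast to float8_e4m3fn or bfloat16.

```python
import math

# E2M1 magnitudes indexed by the low 3 bits of an FP4 code (bit 3 is the sign).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def decode_e2m1(code: int) -> float:
    """Decode one 4-bit E2M1 code (0..15) to a Python float."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def decode_e8m0(byte: int) -> float:
    """Decode one E8M0 scale byte as 2**(byte - 127); 0xFF encodes NaN."""
    if byte == 0xFF:
        return math.nan
    return math.ldexp(1.0, byte - 127)


def dequant_mxfp4_reference(packed, scales, num_rows, num_cols, sf_vector_size=32):
    """Dequantize a [num_rows][num_cols // 2] list of packed uint8 values.

    `scales` is a [num_rows][num_cols // sf_vector_size] list of E8M0 bytes.
    Assumption (hypothetical): the low nibble holds the even column and the
    high nibble holds the odd column.
    """
    assert num_cols % 2 == 0 and num_cols % sf_vector_size == 0
    out = [[0.0] * num_cols for _ in range(num_rows)]
    for r in range(num_rows):
        for c in range(num_cols):
            byte = packed[r][c // 2]
            code = (byte >> 4) if c % 2 else (byte & 0xF)
            # One shared scale per block of sf_vector_size columns in this row.
            scale = decode_e8m0(scales[r][c // sf_vector_size])
            out[r][c] = decode_e2m1(code) * scale
    return out


# One row of 32 FP4 values (16 bytes) sharing a single scale block.
packed = [[0x21, 0xF7] + [0x00] * 14]
scales = [[128]]  # 2**(128 - 127) = 2.0
row = dequant_mxfp4_reference(packed, scales, num_rows=1, num_cols=32)[0]
print(row[:4])  # [1.0, 2.0, 12.0, -12.0]
```

The scale lookup `c // sf_vector_size` mirrors the documented scale shape [num_rows, num_cols // SF_VECTOR_SIZE]. Each group of SF_VECTOR_SIZE consecutive unpacked columns in a row shares one E8M0 exponent.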