Mojo function
dequant_mxfp4
def dequant_mxfp4[*, SF_VECTOR_SIZE: Int = Int(32)](ctx: DeviceContext, output: TileTensor[Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], input: TileTensor[Storage=input.Storage, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], scales: TileTensor[Storage=scales.Storage, address_space=scales.address_space, linear_idx_type=scales.linear_idx_type, element_size=scales.element_size], num_rows: Int, num_cols: Int, pdl_level: PDLLevel = PDLLevel())
Dequantize MXFP4 packed weights to FP8 or BF16.
Args:
- ctx (DeviceContext): Device context for kernel launch.
- output (TileTensor[Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor [num_rows, num_cols] of float8_e4m3fn or bfloat16.
- input (TileTensor[Storage=input.Storage, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size]): Input tensor [num_rows, num_cols // 2] of uint8 (packed FP4).
- scales (TileTensor[Storage=scales.Storage, address_space=scales.address_space, linear_idx_type=scales.linear_idx_type, element_size=scales.element_size]): Scale tensor [num_rows, num_cols // SF_VECTOR_SIZE] of float8_e8m0fnu.
- num_rows (Int): Number of rows (N dimension for weights).
- num_cols (Int): Number of columns (K dimension, unpacked).
- pdl_level (PDLLevel): PDL optimization level for kernel launch.
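
The tensor shapes above imply a specific data layout: two FP4 values per input byte, and one scale per SF_VECTOR_SIZE consecutive columns. The following is a minimal CPU reference sketch of that layout in plain Python. It is not the kernel's implementation. The element and scale encodings (E2M1 and E8M0) follow the OCP Microscaling format. The function name `dequant_mxfp4_reference`, its helpers, and the nibble order (even column in the low nibble) are assumptions made for illustration and are not confirmed by this page.

```python
# Reference (CPU) sketch of MXFP4 dequantization; not the Mojo kernel itself.

# FP4 E2M1 magnitudes, indexed by the low 3 bits of the 4-bit code.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]

SF_VECTOR_SIZE = 32  # elements sharing one scale, matching the default above


def decode_e2m1(code):
    """Decode a 4-bit E2M1 code: bit 3 is the sign, bits 0-2 pick the magnitude."""
    magnitude = E2M1_MAGNITUDES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def decode_e8m0(byte):
    """Decode an E8M0 scale: an unsigned power of two, 2**(byte - 127); 0xFF is NaN."""
    if byte == 0xFF:
        return float("nan")
    return 2.0 ** (byte - 127)


def dequant_mxfp4_reference(packed, scales, num_rows, num_cols,
                            sf_vector_size=SF_VECTOR_SIZE):
    """Hypothetical helper mirroring the documented shapes.

    packed: num_rows lists of num_cols // 2 bytes (two FP4 codes per byte).
    scales: num_rows lists of num_cols // sf_vector_size E8M0 bytes.
    Returns num_rows lists of num_cols Python floats.
    """
    output = []
    for r in range(num_rows):
        row = []
        for c in range(num_cols):
            byte = packed[r][c // 2]
            # ASSUMPTION: even column in the low nibble, odd column in the high nibble.
            code = (byte >> 4) & 0xF if c % 2 else byte & 0xF
            scale = decode_e8m0(scales[r][c // sf_vector_size])
            row.append(decode_e2m1(code) * scale)
        output.append(row)
    return output
```

With num_cols = 32 and the default SF_VECTOR_SIZE, each row takes 16 packed bytes and a single scale byte. The sketch returns unrounded floats, whereas the kernel writes float8_e4m3fn or bfloat16, so results can differ by the rounding of the output type.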