Mojo function
quantize_mxfp4_amd
quantize_mxfp4_amd[out_dtype: DType = DType.uint8, scales_dtype: DType = DType.float8_e8m0fnu, in_dtype: DType = DType.bfloat16, //, *, num_max_threads: Int = 512](ctx: DeviceContext, output_tile: TileTensor[out_dtype, output_tile.LayoutType, output_tile.origin, linear_idx_type=output_tile.linear_idx_type, element_size=output_tile.element_size], scales_tile: TileTensor[scales_dtype, scales_tile.LayoutType, scales_tile.origin, linear_idx_type=scales_tile.linear_idx_type, element_size=scales_tile.element_size], input_tile: TileTensor[in_dtype, input_tile.LayoutType, input_tile.origin, linear_idx_type=input_tile.linear_idx_type, element_size=input_tile.element_size])
Quantize BF16 activations to MXFP4 on AMD CDNA4 (MI355X).
Produces packed uint8 output and 2D E8M0 block scales compatible with dequant_mxfp4() and V_MFMA_SCALE_F32_16X16X128_F8F6F4.
NOTE: The 2D scales layout is a stand-in. The optimized CDNA4 layout will likely be 6D (32x32 tiles) or 7D (16x16 tiles), mirroring how SM100 uses a 5D interleaved layout for its tensor core scale feed.
Args:
- ctx (DeviceContext): Device context.
- output_tile (TileTensor): Output [M, K//2] uint8 (packed FP4).
- scales_tile (TileTensor): Output [M, K//32] float8_e8m0fnu (block scales).
- input_tile (TileTensor): Input [M, K] bfloat16.
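To illustrate the shapes and semantics of the outputs, here is a hedged NumPy reference quantizer (not the AMD kernel, and `quantize_mxfp4_ref` and `FP4_MAGS` are hypothetical names): each 32-element block gets one E8M0 power-of-two scale byte, and each pair of FP4 (E2M1) codes is packed into one uint8, matching the [M, K//2] and [M, K//32] output shapes above.

```python
import numpy as np

# FP4 (E2M1) magnitudes indexed by the low 3 bits of each code; bit 3 is the sign.
FP4_MAGS = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_mxfp4_ref(x: np.ndarray):
    """Reference MXFP4 quantizer for an [M, K] array, K a multiple of 32.

    Returns (packed, scales):
      packed: [M, K//2] uint8, two FP4 codes per byte (low nibble = even column)
      scales: [M, K//32] uint8, E8M0 block scales (biased power-of-two exponents)
    """
    m, k = x.shape
    assert k % 32 == 0
    blocks = x.reshape(m, k // 32, 32).astype(np.float64)

    # Per-block scale: pick the exponent that maps the block max near
    # FP4's largest magnitude, 6.0 = 1.5 * 2^2.
    amax = np.abs(blocks).max(axis=2)
    exp = np.floor(np.log2(np.where(amax > 0, amax, 1.0))) - 2.0
    scales = (exp + 127).astype(np.uint8)  # E8M0: biased exponent, no mantissa

    # Divide out the scale, then round each element to the nearest FP4 magnitude.
    scaled = blocks / np.exp2(exp)[..., None]
    idx = np.abs(np.abs(scaled)[..., None] - FP4_MAGS).argmin(axis=-1)
    sign = (scaled < 0).astype(np.uint8)
    codes = ((sign << 3) | idx.astype(np.uint8)).reshape(m, k)

    # Pack two 4-bit codes per output byte (even column in the low nibble).
    packed = (codes[:, 0::2] | (codes[:, 1::2] << 4)).astype(np.uint8)
    return packed, scales
```

Dequantization inverts this per element: decode the nibble back to a signed FP4 value and multiply by `2^(scale - 127)`; the actual on-device layout of the scales may differ from this flat 2D stand-in, as the note above explains.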