Mojo function
launch_mla_sm100_decode_native_fp8
launch_mla_sm100_decode_native_fp8[q_type: DType, KVLUTType: MHAOperand, output_type: DType, SplitAccumType: OptionalPointer, MaskType: MHAMask, config: MLA_SM100_Decode_Config, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, ragged: Bool = False](q_tma: TMATensorTile[KVLUTType.dtype, 2, IndexList(config.BM, config.BK0, Tuple()), _default_desc_shape[2, KVLUTType.dtype, IndexList(config.BM, config.BK0, Tuple()), config.kv_tma_swizzle_mode]()], k_tma: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config.BK1, 1, config.BK0, Tuple()), config.kv_tma_swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config.BK1, 1, config.BK0, Tuple()), config.kv_tma_swizzle_mode]()], o_tma: TMATensorTile[output_type, 2, IndexList(config.out_rows, config.BN, Tuple()), _default_desc_shape[2, output_type, IndexList(config.out_rows, config.BN, Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType, lse_accum_split_ptr: SplitAccumType, scale: Float32, batch_size: Int, block_z: Int, num_partitions: Int, q_max_seq_len: Int, valid_len: ValidLengthType, mask: MaskType, scales_ptr: UnsafePointer[Float32, MutAnyOrigin], scalar_args_buf: LayoutTensor[DType.int64, scalar_args_buf.layout, scalar_args_buf.origin, element_layout=scalar_args_buf.element_layout, layout_int_type=scalar_args_buf.layout_int_type, linear_idx_type=scalar_args_buf.linear_idx_type, masked=scalar_args_buf.masked, alignment=scalar_args_buf.alignment], ctx: DeviceContext)
Launch the native FP8 MLA decode kernel with FP8 Q TMA.
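This dedicated launch function exists for the native FP8 path because the Q TMA tile has an FP8 dtype with SWIZZLE_64B, rather than BF16 with SWIZZLE_128B. In the signature, q_tma is typed with the KV cache dtype (KVLUTType.dtype) and the KV swizzle mode (config.kv_tma_swizzle_mode), the same dtype and swizzle mode as k_tma.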
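The swizzle difference follows from element size. A TMA swizzle mode covers a span of 32, 64, or 128 bytes, and when swizzling is enabled the inner box dimension, measured in bytes, must not exceed that span. The following is a minimal sketch under stated assumptions: tma_swizzle_bytes is a hypothetical helper, not part of the MAX or Mojo API. It assumes a 64-element inner (K) dimension per Q row and the common choice of a swizzle span equal to the row width in bytes. The kernel itself takes its mode from config.kv_tma_swizzle_mode.

```mojo
# Hypothetical helper for illustration only; not part of the MAX/Mojo API.
fn tma_swizzle_bytes(row_elems: Int, elem_bytes: Int) -> Int:
    """Return the smallest TMA swizzle span (32, 64, or 128 bytes) that
    covers one row of `row_elems` elements, capped at 128 bytes."""
    var row_bytes = row_elems * elem_bytes
    if row_bytes <= 32:
        return 32
    elif row_bytes <= 64:
        return 64
    return 128


fn main():
    # Assumption: 64 elements along the inner (K) dimension of a Q row.
    var row_elems = 64
    # FP8 stores 1 byte per element, BF16 stores 2 bytes per element.
    print("FP8  (1 B/elem):", tma_swizzle_bytes(row_elems, 1), "byte swizzle")
    print("BF16 (2 B/elem):", tma_swizzle_bytes(row_elems, 2), "byte swizzle")
```

With 64 elements per row, the helper returns 64 for FP8 and 128 for BF16. This matches the SWIZZLE_64B and SWIZZLE_128B modes named above. Because the dtype and swizzle mode are compile-time parameters of TMATensorTile, the FP8 Q tile is a distinct type from the BF16 one.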