
Mojo function

launch_mla_sm100_decode_fp8_per_token_scale_rope_aware

launch_mla_sm100_decode_fp8_per_token_scale_rope_aware[
    q_type: DType,
    KVLUTType: MHAOperand,
    output_type: DType,
    SplitAccumType: OptionalPointer,
    MaskType: MHAMask,
    config: MLA_SM100_Decode_Config,
    ValidLengthType: OptionalPointer,
    _is_cache_length_accurate: Bool = False,
    ragged: Bool = False,
    has_per_token_scales: Bool = False
](
    q_nope_tma: TMATensorTile[DType.float8_e4m3fn, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, DType.float8_e4m3fn, IndexList(config, config, __list_literal__=Tuple()), config.content_swizzle_mode]()],
    q_rope_tma: TMATensorTile[DType.bfloat16, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, DType.bfloat16, IndexList(config, config, __list_literal__=Tuple()), config.rope_swizzle_mode]()],
    k_content_tma: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config, 1, config, __list_literal__=Tuple()), config.content_swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config, 1, config, __list_literal__=Tuple()), config.content_swizzle_mode]()],
    k_rope_tma: TMATensorTile[DType.bfloat16, 3, _padded_shape[3, DType.bfloat16, IndexList(config, 1, config, __list_literal__=Tuple()), config.rope_swizzle_mode](), _ragged_shape[3, DType.bfloat16, IndexList(config, 1, config, __list_literal__=Tuple()), config.rope_swizzle_mode]()],
    scale_tma: TMATensorTile[DType.float32, 2, IndexList(1, config, __list_literal__=Tuple())],
    o_tma: TMATensorTile[output_type, 2, IndexList(config, config, __list_literal__=Tuple()), _default_desc_shape[2, output_type, IndexList(config, config, __list_literal__=Tuple()), config.swizzle_mode]()],
    kv_lut: KVLUTType,
    lse_accum_split_ptr: SplitAccumType,
    scale: Float32,
    batch_size: Int,
    block_z: Int,
    num_partitions: Int,
    q_max_seq_len: Int,
    valid_len: ValidLengthType,
    mask: MaskType,
    q_scale_ptr: UnsafePointer[Float32, MutAnyOrigin],
    scalar_args_buf: TileTensor[DType.int64, scalar_args_buf.LayoutType, scalar_args_buf.origin, linear_idx_type=scalar_args_buf.linear_idx_type, element_size=scalar_args_buf.element_size],
    ctx: DeviceContext
)

Launch the FP8 per-token-scale rope-aware MLA decode kernel with split content/rope TMAs.

This is a dedicated launch function for the SnapMLA FP8 per-token-scale rope-aware path. Q and K are each split into an FP8 content part (512 dims, SWIZZLE_64B) and a BF16 rope part (64 dims, SWIZZLE_128B), requiring 5 TMA descriptors (K content, K rope, per-token scales, Q_nope, Q_rope). Per-token scales are loaded via TMA alongside the content and rope data.
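To make the content/rope split concrete, here is a hedged NumPy sketch of the data layout this kernel consumes. It is not the library API: the 512/64 split and the per-token scaling of the FP8 content half come from the description above, but the helper name, the scale formula (row max mapped to the e4m3 finite maximum of 448), and the use of a clip to stand in for a real `float8_e4m3fn` cast are illustrative assumptions.

```python
import numpy as np

# Illustrative constants from the description above; not imported from the library.
CONTENT_DIM = 512   # FP8 content dims (SWIZZLE_64B in the kernel)
ROPE_DIM = 64       # BF16 rope dims (SWIZZLE_128B in the kernel)
E4M3_MAX = 448.0    # largest finite value representable in float8_e4m3fn

def split_and_quantize(q):
    """Hypothetical helper: q is [num_tokens, 576] float32.

    Returns (quantized content, per-token scales, rope part).
    """
    content, rope = q[:, :CONTENT_DIM], q[:, CONTENT_DIM:]
    # One scale per token so each row uses the full e4m3 dynamic range.
    scales = np.abs(content).max(axis=1, keepdims=True) / E4M3_MAX
    scales = np.maximum(scales, 1e-12)  # guard against all-zero rows
    # Stand-in for the FP8 cast: the real kernel stores float8_e4m3fn values.
    content_q = np.clip(content / scales, -E4M3_MAX, E4M3_MAX)
    return content_q, scales.squeeze(1), rope

rng = np.random.default_rng(0)
q = rng.standard_normal((4, CONTENT_DIM + ROPE_DIM)).astype(np.float32)
content_q, scales, rope = split_and_quantize(q)
# Dequantizing with the per-token scale recovers the content half.
recon = content_q * scales[:, None]
print(content_q.shape, rope.shape, bool(np.allclose(recon, q[:, :CONTENT_DIM])))
```

In the actual launch, the quantized content, the BF16 rope slice, and the per-token scales travel through separate TMA descriptors (`k_content_tma`, `k_rope_tma`, `scale_tma`), which is why the scales are a first-class input rather than being folded into the data.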
