Skip to main content

Mojo function

launch_mla_sm100_decode_enqueue_kernel

launch_mla_sm100_decode_enqueue_kernel[q_type: DType, KVLUTType: MHAOperand, output_type: DType, SplitAccumType: OptionalPointer, MaskType: MHAMask, config: MLA_SM100_Decode_Config, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool = False, ragged: Bool = False](q_tma: TMATensorTile[q_type, 2, IndexList(VariadicList(config.BM, config.BK0), Tuple()), _default_desc_shape[2, q_type, IndexList(VariadicList(config.BM, config.BK0), Tuple()), config.swizzle_mode]()], k_tma: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BK1, 1, config.BK0), Tuple()), config.kv_tma_swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BK1, 1, config.BK0), Tuple()), config.kv_tma_swizzle_mode]()], o_tma: TMATensorTile[output_type, 2, IndexList(VariadicList(config.out_rows, config.BN), Tuple()), _default_desc_shape[2, output_type, IndexList(VariadicList(config.out_rows, config.BN), Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType, lse_accum_split_ptr: SplitAccumType, scale: Float32, batch_size: Int, block_z: Int, num_partitions: Int, q_max_seq_len: Int, valid_len: ValidLengthType, mask: MaskType, scales_ptr: UnsafePointer[Float32, MutAnyOrigin], scalar_args_buf: LayoutTensor[DType.int64, scalar_args_buf.layout, scalar_args_buf.origin, element_layout=scalar_args_buf.element_layout, layout_int_type=scalar_args_buf.layout_int_type, linear_idx_type=scalar_args_buf.linear_idx_type, masked=scalar_args_buf.masked, alignment=scalar_args_buf.alignment], ctx: DeviceContext)

Was this page helpful?