Mojo function
launch_mla_sm100_decode_enqueue_kernel
launch_mla_sm100_decode_enqueue_kernel[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, config: MLA_SM100_Decode_Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, _use_valid_length: Bool = False, _is_cache_length_accurate: Bool = False, ragged: Bool = False](q_tma: TMATensorTile[KVLUTType.dtype, tile_layout_k_major[KVLUTType.dtype, config.BM, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[KVLUTType.dtype, 2, IndexList[2, DType.int64](config.BM, config.BN, Tuple[]()), config.swizzle_mode]()], k_tma: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BM, 1, config.BN, Tuple[]()), config)], o_tma: TMATensorTile[output_type, tile_layout_k_major[output_type, config.out_rows, config.BN, config.swizzle_mode](), _tma_desc_tile_layout[output_type, 2, IndexList[2, DType.int64](config.out_rows, config.BN, Tuple[]()), config.swizzle_mode]()], kv_lut: KVLUTType, scale: Float32, batch_size: Int, num_partitions: Int, max_cache_valid_length: Int, valid_len: ValidLengthType, mask: MaskType, score_mod: ScoreModType, ctx: DeviceContext)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!