Skip to main content

Mojo function

depth512_load

depth512_load[KVLUTType: MHAOperand, MaskType: MHAMask, qkv_dtype: DType, config: Depth512SM100Config[qkv_dtype], ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt](smem: Depth512AttentionSMem[config], score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=Depth512SM100Config[qkv_dtype].BM, group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=Depth512SM100Config[qkv_dtype].num_qk_stages](), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=Depth512SM100Config[qkv_dtype].BM, group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=Depth512SM100Config[qkv_dtype].num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList((config // 2), 1, config, __list_literal__=Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config, 1, (config // 4), __list_literal__=Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config, 1, (config // 4), __list_literal__=Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType)

Was this page helpful?