Mojo function
depth512_load
depth512_load[KVLUTType: MHAOperand, MaskType: MHAMask, qkv_dtype: DType, config: Depth512SM100Config[qkv_dtype], ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, is_leader: Bool](smem: Depth512AttentionSMem[config], score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, 4 if config.fuse_gqa else 3, _padded_shape[4 if config.fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=config.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[4 if config.fuse_gqa else 3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=config.BM, group=config.group, depth=config.qk_depth, decoding=False, fuse_gqa=config.fuse_gqa, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows((config // 2), KVLUTType.page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows((config // 2), KVLUTType.page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BK1, KVLUTType.page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(kv_sub_tile_rows(config.BK1, KVLUTType.page_size), 1, config, __list_literal__=NoneType(None)), config.swizzle_mode]()], kv_lut: KVLUTType)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!