Mojo function
fa4_load
fa4_load[
    KVLUTType: MHAOperand,
    MaskType: MHAMask,
    config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype],
    ValidLengthType: OptionalPointer,
    _is_cache_length_accurate: Bool,
    MaxSeqLenType: OptionallyStaticInt
](
    smem: SM100AttentionSMem[config],
    score_row: UInt32,
    num_keys: UInt32,
    seq_info: SeqInfo,
    max_seq_len: MaxSeqLenType,
    mask: MaskType,
    q_tma_op: TMATensorTile[
        KVLUTType.dtype,
        4 if config.fuse_gqa else 3,
        _padded_shape[
            4 if config.fuse_gqa else 3,
            KVLUTType.dtype,
            q_smem_shape[
                KVLUTType.dtype,
                config.swizzle_mode,
                BM=(config // 2),
                group=config.group,
                depth=config.qk_depth,
                decoding=False,
                fuse_gqa=config.fuse_gqa,
                num_qk_stages=config.num_qk_stages
            ](),
            config.swizzle_mode
        ](),
        _ragged_shape[
            4 if config.fuse_gqa else 3,
            KVLUTType.dtype,
            q_smem_shape[
                KVLUTType.dtype,
                config.swizzle_mode,
                BM=(config // 2),
                group=config.group,
                depth=config.qk_depth,
                decoding=False,
                fuse_gqa=config.fuse_gqa,
                num_qk_stages=config.num_qk_stages
            ](),
            config.swizzle_mode
        ]()
    ],
    k_tma_op: TMATensorTile[
        KVLUTType.dtype,
        3,
        _padded_shape[
            3,
            KVLUTType.dtype,
            IndexList(config.k_rows_per_cta(), 1, config, __list_literal__=Tuple()),
            config.swizzle_mode
        ](),
        _ragged_shape[
            3,
            KVLUTType.dtype,
            IndexList(config.k_rows_per_cta(), 1, config, __list_literal__=Tuple()),
            config.swizzle_mode
        ]()
    ],
    v_tma_op: TMATensorTile[
        KVLUTType.dtype,
        3,
        _padded_shape[
            3,
            KVLUTType.dtype,
            IndexList(config, 1, config.v_cols_per_cta(), __list_literal__=Tuple()),
            config.swizzle_mode
        ](),
        _ragged_shape[
            3,
            KVLUTType.dtype,
            IndexList(config, 1, config.v_cols_per_cta(), __list_literal__=Tuple()),
            config.swizzle_mode
        ]()
    ],
    kv_lut: KVLUTType
)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!