Skip to main content

Mojo function

fa4_load

fa4_load[KVLUTType: MHAOperand, MaskType: MHAMask, config: FA4Config[config.qkv_dtype, rope_dtype=config.rope_dtype, scale_dtype=config.scale_dtype], ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt](smem: SM100AttentionSMem[config], score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.qk_depth, decoding=False, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.BK0), Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.BK0), Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.padded_ov_depth), Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(VariadicList(config.BN, 1, config.padded_ov_depth), Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType)

Was this page helpful?