Skip to main content

Mojo function

fa4_load

fa4_load[KVLUTType: MHAOperand, MaskType: MHAMask, config: FA4Config, ValidLengthType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt](mbars: FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=EnableForcedOrdering], score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.depth, decoding=False, num_qk_stages=config.num_qk_stages](), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config.BM // 2), group=config.group, depth=config.depth, decoding=False, num_qk_stages=config.num_qk_stages](), config.swizzle_mode]()], k_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config.BN, 1, config.BK0, Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config.BN, 1, config.BK0, Tuple()), config.swizzle_mode]()], v_tma_op: TMATensorTile[KVLUTType.dtype, 3, _padded_shape[3, KVLUTType.dtype, IndexList(config.BN, 1, config.padded_depth, Tuple()), config.swizzle_mode](), _ragged_shape[3, KVLUTType.dtype, IndexList(config.BN, 1, config.padded_depth, Tuple()), config.swizzle_mode]()], kv_lut: KVLUTType)

Was this page helpful?