
Mojo function

flash_attention

flash_attention[
    rank: Int,
    type: DType,
    q_shape: DimList, //,
    use_score_mod: Bool = False,
    config: MHAConfig = MHAConfig(type, UInt(q_shape.get[::Int]()), UInt(q_shape.get[::Int]()), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), UInt(2 if _accelerator_arch().__contains__[::Bool,::Origin[$2]](__init__[__mlir_type.!kgen.string](":90")) else 4), UInt(1), FlashAttentionAlgorithm()),
    decoding_warp_split_k: Bool = False,
    naive_kernel: Bool = False
](
    output: NDBuffer[type, rank, origin, shape, strides],
    q: NDBuffer[type, rank, origin, q_shape, strides],
    k: NDBuffer[type, rank, origin, shape, strides],
    v: NDBuffer[type, rank, origin, shape, strides],
    mask: NDBuffer[type, rank, origin, shape, strides, alignment=alignment, address_space=address_space, exclusive=exclusive],
    scale: SIMD[float32, 1],
    context: DeviceContextPtr = DeviceContextPtr(),
    num_partitions: OptionalReg[Int] = OptionalReg[Int]({:i1 0, 1})
)

flash_attention[
    rank: Int,
    cache_t: KVCacheT,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    type: DType,
    q_shape: DimList, //,
    use_score_mod: Bool = False,
    config: MHAConfig = MHAConfig(type, UInt(q_shape.get[::Int]()), UInt(q_shape.get[::Int]()), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), UInt(2 if _accelerator_arch().__contains__[::Bool,::Origin[$2]](__init__[__mlir_type.!kgen.string](":90")) else 4), UInt(1), FlashAttentionAlgorithm()),
    ragged: Bool = False,
    decoding_warp_split_k: Bool = False,
    naive_kernel: Bool = False
](
    output: NDBuffer[type, rank, origin, shape, strides],
    q: NDBuffer[type, rank, origin, q_shape, strides],
    k: cache_t,
    v: cache_t,
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    valid_length: NDBuffer[uint32, 1, origin, shape, strides],
    scale: SIMD[float32, 1],
    ctx: DeviceContext,
    q_max_seq_len: OptionalReg[Int] = OptionalReg[Int]({:i1 0, 1}),
    kv_input_row_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]] = OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]]({:i1 0, 1}),
    num_partitions: OptionalReg[Int] = OptionalReg[Int]({:i1 0, 1})
)

Flash attention 2 algorithm. The kernel computes:

(1) Transpose (Q) BSHD -> BHSD
(2) Transpose (K) BSHD -> BHSD
(3) Transpose (V) BSHD -> BHSD
(4) P = Bmm(Q, K); P is also called the "score"
(5) P = P * scale + mask
(6) P = softmax(P)
(7) O = Bmm(P, V)
(8) Output = Transpose(O)

B, S, H, D denote batch size, sequence length, head count, and depth, respectively. Steps (1), (2), and (3) happen while loading the data into shared memory; step (8) happens when writing the output to global memory.
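For reference, steps (4) through (7) amount to standard scaled-dot-product attention. The NumPy sketch below illustrates only the math on arrays already in BHSD layout; it is not the Mojo kernel API, and every name in it is hypothetical.

```python
import numpy as np

def attention_reference(q, k, v, mask, scale):
    """Illustrates steps (4)-(7) for arrays already in BHSD layout.

    q, k, v: float arrays of shape [B, H, S, D].
    mask:    additive mask broadcastable to [B, H, S, S].
    """
    p = q @ k.transpose(0, 1, 3, 2)        # (4) P = Bmm(Q, K), shape [B, H, S, S]
    p = p * scale + mask                   # (5) scale the score and add the mask
    p = p - p.max(axis=-1, keepdims=True)  # keep the exponentials numerically stable
    p = np.exp(p)
    p /= p.sum(axis=-1, keepdims=True)     # (6) softmax over the key positions
    return p @ v                           # (7) O = Bmm(P, V), shape [B, H, S, D]
```

Unlike this reference, the flash attention kernel fuses these steps and never materializes the full score matrix; the result matches the reference up to floating-point error.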

All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS.

This kernel also handles the grouped attention optimization. In that case, the shape of K and V is BShD, where h = H / num_groups.
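As a concrete illustration of the grouped case, the snippet below lists the shapes involved and a typical query-head-to-KV-head mapping; all numbers are hypothetical.

```python
B, S, H, D = 2, 1024, 32, 128  # hypothetical batch, sequence, query heads, depth
num_groups = 4                 # query heads that share one KV head
h = H // num_groups            # 8 KV heads

q_shape = (B, S, H, D)          # BSHD
kv_shape = (B, S, h, D)         # BShD: K and V carry only h heads
mask_shape_bss = (B, S, S)      # BSS mask, shared across heads
mask_shape_bhss = (B, H, S, S)  # BHSS mask, one slice per query head

# A common mapping: query head i attends against KV head i // num_groups.
kv_head_for_query_head = [i // num_groups for i in range(H)]
```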

This kernel handles batches with different valid lengths (i.e., the lengths before padding). These lengths are passed in the valid_length argument.
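For example (hypothetical values), a batch of three rows padded to 512 tokens, of which only the first 500, 340, and 512 tokens are real, would pass valid_length = [500, 340, 512]:

```python
import numpy as np

max_seq_len = 512
valid_length = np.array([500, 340, 512], dtype=np.uint32)  # real tokens per batch row

# Positions at or beyond a row's valid length are padding, so only each row's
# first valid_length tokens take part in attention.
positions = np.arange(max_seq_len)
is_padding = positions[None, :] >= valid_length[:, None]   # shape [B, S]
```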

flash_attention[
    rank: Int,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    type: DType,
    q_shape: DimList, //,
    use_score_mod: Bool = False,
    config: MHAConfig = MHAConfig(type, UInt(q_shape.get[::Int]()), UInt(q_shape.get[::Int]()), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), OptionalReg[UInt]({:i1 0, 1}), UInt(2 if _accelerator_arch().__contains__[::Bool,::Origin[$2]](__init__[__mlir_type.!kgen.string](":90")) else 4), UInt(1), FlashAttentionAlgorithm()),
    decoding_warp_split_k: Bool = False,
    naive_kernel: Bool = False
](
    output: NDBuffer[type, rank, origin, shape, strides],
    q: NDBuffer[type, rank, origin, q_shape, strides],
    k: NDBuffer[type, rank, origin, shape, strides],
    v: NDBuffer[type, rank, origin, shape, strides],
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    scale: SIMD[float32, 1],
    ctx: DeviceContext,
    num_partitions: OptionalReg[Int] = OptionalReg[Int]({:i1 0, 1})
)
