
Mojo function

flash_attention

flash_attention[dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(SIMD(Int[IntTuple](q_layout.shape[2])), SIMD(Int[IntTuple](q_layout.shape[3])), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), SIMD(4), SIMD(1), FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[output.dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: LayoutTensor[k.dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[v.dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], mask: LayoutTensor[mask.dtype, mask.layout, mask.origin, element_layout=mask.element_layout, layout_int_type=mask.layout_int_type, linear_idx_type=mask.linear_idx_type, masked=mask.masked, alignment=mask.alignment], scale: Float32, context: DeviceContextPtr = DeviceContextPtr(), num_partitions: Optional[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None)

flash_attention[cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(SIMD(Int[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), SIMD(Int[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), SIMD(4), SIMD(1), FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), ragged: Bool = False, sink: Bool = False, decoding_warp_split_k: Bool = False, naive_kernel: Bool = False](output: LayoutTensor[output.dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: cache_t, v: cache_t, mask_functor: mask_t, valid_length: LayoutTensor[DType.uint32, valid_length.layout, valid_length.origin, element_layout=valid_length.element_layout, layout_int_type=valid_length.layout_int_type, linear_idx_type=valid_length.linear_idx_type, masked=valid_length.masked, alignment=valid_length.alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: Optional[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None, num_partitions: Optional[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None, decode_dispatch_metadata: OptionalReg[MHADecodeDispatchMetadata] = None)

Flash attention 2 algorithm. Computes: (1) transpose Q from BSHD to BHSD; (2) transpose K from BSHD to BHSD; (3) transpose V from BSHD to BHSD; (4) P = Bmm(Q, K), where P is also called the "score"; (5) P = P * scale + mask; (6) P = softmax(P); (7) O = Bmm(P, V); (8) output = Transpose(O).

B, S, H, D denote batch size, sequence length, head count, and depth, respectively. Steps (1), (2), and (3) happen while loading the data into shared memory. Step (8) happens when writing the output to global memory.

All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS.
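As a point of reference, the following NumPy sketch spells out the math the kernel computes (steps (1) through (8) above) for BSHD inputs with either a BSS or a BHSS mask. It is illustrative only: `flash_attention_reference` and the concrete sizes are hypothetical, the transpose of K inside step (4) is implied by the matrix dimensions, and the real kernel fuses these steps rather than materializing the full score matrix.

```python
import numpy as np

def flash_attention_reference(q, k, v, mask, scale):
    """Dense reference for softmax(score * scale + mask) @ V.

    q, k, v: (B, S, H, D) arrays (BSHD layout).
    mask:    (B, S, S) or (B, H, S, S); a BSS mask is broadcast over heads.
    """
    if mask.ndim == 3:                       # BSS -> (B, 1, S, S) so it broadcasts over H
        mask = mask[:, None]
    # Steps (1)-(3): transpose BSHD -> BHSD.
    q_t, k_t, v_t = (x.transpose(0, 2, 1, 3) for x in (q, k, v))
    # Steps (4)-(5): score = Q @ K^T, then scale and add the mask.
    p = q_t @ k_t.transpose(0, 1, 3, 2) * scale + mask
    # Step (6): numerically stable row-wise softmax.
    p = np.exp(p - p.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # Steps (7)-(8): O = P @ V, then transpose BHSD -> BSHD.
    return (p @ v_t).transpose(0, 2, 1, 3)

# Hypothetical sizes for illustration.
B, S, H, D = 2, 128, 16, 64
q = np.random.randn(B, S, H, D).astype(np.float32)
k = np.random.randn(B, S, H, D).astype(np.float32)
v = np.random.randn(B, S, H, D).astype(np.float32)
mask = np.zeros((B, S, S), dtype=np.float32)            # BSS mask, shared across heads
out = flash_attention_reference(q, k, v, mask, scale=1.0 / np.sqrt(D))
assert out.shape == (B, S, H, D)
```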

This kernel also supports the grouped attention optimization, in which case K and V have shape BShD, where h = H / num_groups.
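A small shape sketch of that grouped case (hypothetical sizes; the repeat mirrors the common grouped-query mapping in which each K/V head serves a block of consecutive query heads, which is an assumption for illustration, not a statement about the kernel's internals):

```python
import numpy as np

# Hypothetical sizes: 16 query heads grouped over 4 K/V heads.
B, S, H, D, num_groups = 2, 128, 16, 64, 4
h = H // num_groups                                   # K/V head count
k_grouped = np.zeros((B, S, h, D), dtype=np.float32)  # BShD
v_grouped = np.zeros((B, S, h, D), dtype=np.float32)  # BShD
# Repeating each K/V head num_groups times recovers dense BSHD tensors,
# so the dense reference math above applies unchanged.
k_dense = np.repeat(k_grouped, H // h, axis=2)        # (B, S, H, D)
v_dense = np.repeat(v_grouped, H // h, axis=2)
assert k_dense.shape == v_dense.shape == (B, S, H, D)
```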

This kernel handles batches with different valid lengths (that is, sequence lengths before padding). These lengths are passed in the valid_length argument.
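A brief sketch of the valid-length convention (hypothetical values; the additive-mask construction is one common way to zero out attention to padded keys, not necessarily what the kernel does internally):

```python
import numpy as np

# Hypothetical padded batch: max sequence length 8, two sequences whose
# real (unpadded) lengths are 5 and 3.
S = 8
valid_length = np.array([5, 3], dtype=np.uint32)
# True where a key position holds a real token, False where it is padding.
key_is_valid = np.arange(S)[None, :] < valid_length[:, None]            # (B, S)
# One common use: an additive mask that sends padded keys to -inf before
# the softmax, so they receive zero attention weight.
additive_mask = np.where(key_is_valid, 0.0, -np.inf)[:, None, None, :]  # (B, 1, 1, S)
```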

flash_attention[mask_t: MHAMask, dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(SIMD(Int[IntTuple](q_layout.shape[2])), SIMD(Int[IntTuple](q_layout.shape[3])), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), SIMD(4), SIMD(1), FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, _use_valid_length: Bool = False, _padded_ndbuffer: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[output.dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: LayoutTensor[k.dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[v.dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], mask_functor: mask_t, scale: Float32, ctx: DeviceContext, num_partitions: Optional[Int] = None, valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None)
