IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

flash_attention

def flash_attention[dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(Int[IntTuple](q_layout.shape[2]), Int[IntTuple](q_layout.shape[3]), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: LayoutTensor[element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], mask: LayoutTensor[element_layout=mask.element_layout, layout_int_type=mask.layout_int_type, linear_idx_type=mask.linear_idx_type, masked=mask.masked, alignment=mask.alignment], scale: Float32, context: DeviceContext, num_partitions: Optional[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None)

def flash_attention[cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(Int[IntTuple](q_layout.shape[(q_layout.rank() - 2)]), Int[IntTuple](q_layout.shape[(q_layout.rank() - 1)]), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), ragged: Bool = False, sink: Bool = False, decoding_warp_split_k: Bool = False, naive_kernel: Bool = False](output: LayoutTensor[element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: cache_t, v: cache_t, mask_functor: mask_t, valid_length: LayoutTensor[DType.uint32, element_layout=valid_length.element_layout, layout_int_type=valid_length.layout_int_type, linear_idx_type=valid_length.linear_idx_type, masked=valid_length.masked, alignment=valid_length.alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: Optional[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None, num_partitions: Optional[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None, decode_dispatch_metadata: OptionalReg[MHADecodeDispatchMetadata] = None)

Flash attention 2 algorithm. Computes: (1) Transpose(Q) BSHD -> BHSD; (2) Transpose(K) BSHD -> BHSD; (3) Transpose(V) BSHD -> BHSD; (4) P = Bmm(Q, K), where P is also called the "score"; (5) P = P * scale + mask; (6) P = softmax(P); (7) O = Bmm(P, V); (8) Output = Transpose(O).

B, S, H, and D denote batch size, sequence length, head count, and depth, respectively. Steps (1), (2), and (3) happen while loading the data into shared memory. Step (8) happens when writing the output to global memory.
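The eight steps can be checked against a naive reference. The following is a minimal sketch in plain Python, not the kernel's implementation: it assumes nested lists in BSHD order, a mask in the BSS form, and the standard attention score in which Q and K are contracted over the depth dimension. The names `softmax` and `reference_attention` are hypothetical helpers introduced for illustration.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def reference_attention(q, k, v, mask, scale):
    """Naive attention over BSHD nested lists (illustrative only).

    q, k, v: [B][S][H][D]; mask: [B][S][S] (BSS form, an assumption of
    this sketch); returns the output in BSHD order.
    """
    B, S_q, H, D = len(q), len(q[0]), len(q[0][0]), len(q[0][0][0])
    S_k = len(k[0])
    out = [[[[0.0] * D for _ in range(H)] for _ in range(S_q)] for _ in range(B)]
    for b in range(B):
        for h in range(H):
            for i in range(S_q):
                # Steps (1), (2), (4), (5): indexing [b][s][h] reads the BSHD
                # data per head, so the transposes are implicit.
                scores = [
                    scale * sum(q[b][i][h][d] * k[b][j][h][d] for d in range(D))
                    + mask[b][i][j]
                    for j in range(S_k)
                ]
                p = softmax(scores)  # step (6)
                # Steps (3), (7), (8): weight V and write back in BSHD order.
                for d in range(D):
                    out[b][i][h][d] = sum(p[j] * v[b][j][h][d] for j in range(S_k))
    return out
```

A BHSS mask would be indexed as `mask[b][h][i][j]` instead. A reference like this is useful mainly for comparing small inputs against the kernel's output within a floating-point tolerance.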

All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS.

This kernel also handles the grouped attention optimization. In this case, the shapes of K and V are BShD, where h = H / num_groups.
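To make the relation h = H / num_groups concrete, here is a small sketch of how query heads could share K/V heads. It assumes that consecutive query heads are grouped together, which is a common convention but is not stated on this page. The helper names `num_kv_heads` and `kv_head_for` are hypothetical.

```python
def num_kv_heads(num_q_heads, num_groups):
    """Return h = H / num_groups, requiring an exact division."""
    if num_q_heads % num_groups != 0:
        raise ValueError("H must be divisible by num_groups")
    return num_q_heads // num_groups


def kv_head_for(q_head, num_q_heads, num_groups):
    """Map a query head index to its K/V head.

    Assumption: query heads are grouped contiguously, so every block of
    num_groups consecutive query heads reads the same K/V head.
    """
    h = num_kv_heads(num_q_heads, num_groups)
    kv_head = q_head // num_groups
    assert 0 <= kv_head < h
    return kv_head
```

For example, with H = 8 and num_groups = 4, K and V carry h = 2 heads, and under the contiguous-grouping assumption query heads 0 to 3 read K/V head 0 while heads 4 to 7 read K/V head 1.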

This kernel handles batches whose sequences have different valid lengths (i.e., the lengths before padding). These lengths are passed in the valid_length argument.
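The effect of per-batch valid lengths can be pictured as hiding the padded key positions. The sketch below expresses that as an additive BSS mask. This is only a conceptual model: the kernel itself may skip padded positions rather than build such a mask. The name `padding_mask` is a hypothetical helper, and each valid length is assumed to be at least 1.

```python
NEG_INF = float("-inf")


def padding_mask(valid_length, max_seq_len):
    """Build a [B][S][S] additive mask that hides padded key positions.

    valid_length[b] is the unpadded length of batch entry b. Keys at
    positions >= valid_length[b] get -inf, so softmax gives them zero weight.
    """
    return [
        [
            [0.0 if j < n else NEG_INF for j in range(max_seq_len)]
            for _ in range(max_seq_len)
        ]
        for n in valid_length
    ]
```

With `valid_length = [2, 3]` and a padded length of 3, the first batch entry ignores key position 2, while the second attends to all three positions.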

def flash_attention[mask_t: MHAMask, dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(Int[IntTuple](q_layout.shape[2]), Int[IntTuple](q_layout.shape[3]), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, _use_valid_length: Bool = False, _padded_ndbuffer: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: LayoutTensor[element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], mask_functor: mask_t, scale: Float32, ctx: DeviceContext, num_partitions: Optional[Int] = None, valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None)

def flash_attention[mask_t: MHAMask, dtype: DType, output_type: DType, q_tt_layout: TensorLayout, k_tt_layout: TensorLayout, v_tt_layout: TensorLayout, output_tt_layout: TensorLayout, //, config: MHAConfig[dtype] = MHAConfig(q_tt_layout.static_shape[2], q_tt_layout.static_shape[3], Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, _use_valid_length: Bool = False, _padded_ndbuffer: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: TileTensor[output_type, output_tt_layout, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q_tt_layout, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[dtype, k_tt_layout, linear_idx_type=k.linear_idx_type, element_size=k.element_size], v: TileTensor[dtype, v_tt_layout, linear_idx_type=v.linear_idx_type, element_size=v.element_size], mask_functor: mask_t, scale: Float32, ctx: DeviceContext, num_partitions: Optional[Int] = None, valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None)

TileTensor overload of flash_attention. Bridges to LayoutTensor internally.