Mojo function
flash_attention
def flash_attention[dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(Int[IntTuple](q_layout.shape[2]), Int[IntTuple](q_layout.shape[3]), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: LayoutTensor[element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], mask: LayoutTensor[element_layout=mask.element_layout, layout_int_type=mask.layout_int_type, linear_idx_type=mask.linear_idx_type, masked=mask.masked, alignment=mask.alignment], scale: Float32, context: DeviceContext, num_partitions: Optional[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None)
def flash_attention[cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(Int[IntTuple](q_layout.shape[(q_layout.rank() - 2)]), Int[IntTuple](q_layout.shape[(q_layout.rank() - 1)]), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), ragged: Bool = False, sink: Bool = False, decoding_warp_split_k: Bool = False, naive_kernel: Bool = False](output: LayoutTensor[element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: cache_t, v: cache_t, mask_functor: mask_t, valid_length: LayoutTensor[DType.uint32, element_layout=valid_length.element_layout, layout_int_type=valid_length.layout_int_type, linear_idx_type=valid_length.linear_idx_type, masked=valid_length.masked, alignment=valid_length.alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: Optional[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None, num_partitions: Optional[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None, decode_dispatch_metadata: OptionalReg[MHADecodeDispatchMetadata] = None)
Flash attention 2 algorithm. Compute: (1) Transpose (Q) BSHD -> BHSD; (2) Transpose (K) BSHD -> BHSD; (3) Transpose (V) BSHD -> BHSD; (4) P = Bmm(Q, K), contracting over the depth dimension D; P is also called the "score"; (5) P = P * scale + mask; (6) P = softmax(P); (7) O = Bmm(P, V); (8) Output = Transpose(O), BHSD -> BSHD.
B, S, H, and D denote batch size, sequence length, head count, and depth, respectively. Steps (1), (2), and (3) happen while loading the data into shared memory. Step (8) happens when writing the output to global memory.
All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS.
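The eight steps above can be written as an unfused reference in NumPy, which is convenient for checking kernel output on small inputs. This is only a sketch of the math, not the Mojo kernel: the function name `reference_attention` is hypothetical, it materializes the full score matrix (which the fused kernel avoids), and it assumes an additive mask that leaves at least one finite entry per row.

```python
import numpy as np

def reference_attention(q, k, v, mask, scale):
    """Unfused reference for steps (1)-(8). q, k, v are BSHD; mask is BSS or BHSS."""
    # (1)-(3): BSHD -> BHSD
    qt = q.transpose(0, 2, 1, 3)
    kt = k.transpose(0, 2, 1, 3)
    vt = v.transpose(0, 2, 1, 3)
    # (4): batched matmul contracting over depth D, giving scores of shape BHSS
    p = qt @ kt.transpose(0, 1, 3, 2)
    # (5): scale, then add the mask (a BSS mask is broadcast over heads)
    if mask.ndim == 3:
        mask = mask[:, None, :, :]
    p = p * scale + mask
    # (6): softmax over the key axis, shifted by the row max for stability
    p = np.exp(p - p.max(axis=-1, keepdims=True))
    p = p / p.sum(axis=-1, keepdims=True)
    # (7): weighted sum of values, shape BHSD
    o = p @ vt
    # (8): BHSD -> BSHD
    return o.transpose(0, 2, 1, 3)
```

With a causal BSS mask (0 on and below the diagonal, -inf above it), the first query position can only attend to the first key, so the first output row equals the first row of V.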
This kernel also handles the grouped attention optimization. In this case, K and V have shape BShD, where h = H / num_groups.
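One way to picture the grouped case is to expand K and V back to H heads and run ordinary multi-head attention. The NumPy sketch below does that expansion. It assumes that each K/V head is shared by `num_groups` consecutive query heads, which the text above does not state explicitly, and both helper names are hypothetical. The kernel itself presumably indexes the shared head directly rather than copying data.

```python
import numpy as np

def expand_kv_heads(kv, num_groups):
    """Repeat each of the h K/V heads num_groups times: BShD -> BSHD.

    Assumes consecutive query heads share one K/V head (an illustrative
    convention, not confirmed by the source).
    """
    return np.repeat(kv, num_groups, axis=2)

def kv_head_for_query_head(q_head, num_groups):
    """Index of the shared K/V head under the same consecutive-heads assumption."""
    return q_head // num_groups
```

For example, with H = 8 query heads and num_groups = 4, K and V carry h = 2 heads, and query head 5 reads K/V head 1.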
This kernel handles batches with different valid lengths (i.e., the lengths before padding). These lengths are passed in the valid_length argument.
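In terms of the math, ignoring padded positions is equivalent to adding a mask that sends padded key positions to -inf before the softmax. The NumPy sketch below builds such a BSS mask from per-sequence lengths. It is an illustration only: the helper name `padding_mask` is hypothetical, it assumes one length per batch entry with every length at least 1, and it says nothing about how the kernel consumes valid_length internally.

```python
import numpy as np

def padding_mask(valid_length, seq_len):
    """Additive BSS mask: 0 where the key position is valid, -inf where it is padding."""
    valid_length = np.asarray(valid_length)
    key_pos = np.arange(seq_len)
    # is_valid[b, j] is True when key j lies inside sequence b's unpadded length
    is_valid = key_pos[None, :] < valid_length[:, None]
    row = np.where(is_valid, 0.0, -np.inf).astype(np.float32)
    # Every query row of a sequence sees the same set of valid keys
    batch = valid_length.shape[0]
    return np.broadcast_to(row[:, None, :], (batch, seq_len, seq_len)).copy()
```

Rows that belong to padded query positions still produce output values under this mask, so a caller would normally discard them.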
def flash_attention[mask_t: MHAMask, dtype: DType, q_layout: Layout, //, config: MHAConfig[dtype] = MHAConfig(Int[IntTuple](q_layout.shape[2]), Int[IntTuple](q_layout.shape[3]), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, _use_valid_length: Bool = False, _padded_ndbuffer: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: LayoutTensor[element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], v: LayoutTensor[element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment], mask_functor: mask_t, scale: Float32, ctx: DeviceContext, num_partitions: Optional[Int] = None, valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None)
def flash_attention[mask_t: MHAMask, dtype: DType, output_type: DType, q_tt_layout: TensorLayout, k_tt_layout: TensorLayout, v_tt_layout: TensorLayout, output_tt_layout: TensorLayout, //, config: MHAConfig[dtype] = MHAConfig(q_tt_layout.static_shape[2], q_tt_layout.static_shape[3], Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, _use_valid_length: Bool = False, _padded_ndbuffer: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: TileTensor[output_type, output_tt_layout, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[dtype, q_tt_layout, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[dtype, k_tt_layout, linear_idx_type=k.linear_idx_type, element_size=k.element_size], v: TileTensor[dtype, v_tt_layout, linear_idx_type=v.linear_idx_type, element_size=v.element_size], mask_functor: mask_t, scale: Float32, ctx: DeviceContext, num_partitions: Optional[Int] = None, valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin]] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None)
TileTensor overload of flash_attention. Bridges to LayoutTensor internally.