
Mojo function

flash_attention_split_kv

def flash_attention_split_kv[dtype: DType, rank: Int, mask_rank: Int, //, input_k_fn: def[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_v_fn: def[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_k_cache_fn: def[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_v_cache_fn: def[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_mask_fn: def[simd_width: Int, mask_rank: Int](IndexList[mask_rank]) capturing -> SIMD[dtype, simd_width]](q: LayoutTensor[dtype, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k_shape: IndexList[rank], v_shape: IndexList[rank], k_cache_shape: IndexList[(rank + 1)], v_cache_shape: IndexList[(rank + 1)], mask_shape: IndexList[mask_rank], output: LayoutTensor[dtype, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], scale: Float32, ctx: Optional[DeviceContext] = None)

Variant of flash attention that takes the previous KV cache (input_k_cache_fn and input_v_cache_fn) and the current KV tensors (input_k_fn and input_v_fn) as separate arguments.
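The unfused computation that this split-input design avoids can be modeled in a few lines of Python. This is a sketch under stated assumptions, not the Mojo API. It assumes a single head with 2-D (sequence, head-dim) lists, concatenation of cache and current KV along the sequence axis, and a mask that is added to the scaled logits before the softmax. The helper names `concat_seq`, `attention`, and `attention_with_split_kv` are hypothetical.

```python
import math

# Hypothetical reference model (not the Mojo API): one head, 2-D lists,
# rows are sequence positions and columns are head-dimension features.

def concat_seq(prev, cur):
    """Concatenate cached rows and current rows along the sequence axis."""
    return prev + cur

def attention(q, k, v, scale, mask):
    """Naive scaled dot-product attention; the mask is assumed to be additive."""
    out = []
    for i, q_row in enumerate(q):
        # Scaled logits for query row i against every key row, plus the mask.
        logits = [scale * sum(a * b for a, b in zip(q_row, k_row)) + mask[i][j]
                  for j, k_row in enumerate(k)]
        m = max(logits)  # subtract the max for numerical stability
        weights = [math.exp(x - m) for x in logits]
        total = sum(weights)
        # Softmax-weighted average of the value rows.
        out.append([sum(w * v_row[d] for w, v_row in zip(weights, v)) / total
                    for d in range(len(v[0]))])
    return out

def attention_with_split_kv(q, k_cache, v_cache, k, v, scale, mask):
    """Unfused baseline: materialize concat(cache, current), then attend."""
    return attention(q, concat_seq(k_cache, k), concat_seq(v_cache, v), scale, mask)
```

In this model, the lists built by `concat_seq` are the extra copy that the in-place concat fusion described next is meant to avoid.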

Taking the cache and the current tensors separately works around the fact that fusion currently cannot look through a concat. Instead, this kernel fuses the concat in place: the wrapper lambdas input_{k,v}_cache_fn_wrapper read elements at previous sequence positions from the KV cache and elements at current sequence positions from the k and v tensors.
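The index routing performed by such a wrapper can be modeled as follows. This is a minimal Python sketch, not the Mojo implementation. It assumes the concat runs along the sequence axis and that the number of cached positions (the hypothetical `prev_len`) is known. It also uses simplified 2-D `(seq, d)` indexing in place of the real `IndexList[rank]` arguments and ignores the extra dimension in k_cache_shape and v_cache_shape (rank + 1). `make_concat_fn` is a hypothetical name.

```python
# Hypothetical model of in-place concat fusion (not the Mojo API).
# Rather than building concat(cache, current) in memory, a wrapper accessor
# picks the source to read from for each sequence index.

def make_concat_fn(cache_fn, cur_fn, prev_len):
    """Return an accessor over the virtual concatenation of cache and current rows.

    cache_fn(seq, d) reads the previous-sequence KV cache, cur_fn(seq, d) reads
    the current KV tensor, and prev_len is the number of cached positions.
    """
    def concat_fn(seq, d):
        if seq < prev_len:
            return cache_fn(seq, d)       # previous positions come from the KV cache
        return cur_fn(seq - prev_len, d)  # current positions come from k or v, shifted
    return concat_fn

# Usage: two cached positions followed by one current position.
k_cache = [[1.0, 2.0], [3.0, 4.0]]
k_cur = [[5.0, 6.0]]
k_fn = make_concat_fn(lambda s, d: k_cache[s][d], lambda s, d: k_cur[s][d], len(k_cache))
rows = [[k_fn(s, d) for d in range(2)] for s in range(3)]
print(rows)  # [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
```

Because the consumer only calls the accessor, the attention loop sees a single logical KV sequence even though no concatenated buffer is ever allocated.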