Mojo function
flash_attention_split_kv
flash_attention_split_kv[dtype: DType, rank: Int, mask_rank: Int, //, input_k_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_v_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_k_cache_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_v_cache_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_mask_fn: fn[simd_width: Int, mask_rank: Int](IndexList[mask_rank]) capturing -> SIMD[dtype, simd_width]](q: LayoutTensor[dtype, q.layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k_shape: IndexList[rank], v_shape: IndexList[rank], k_cache_shape: IndexList[(rank + 1)], v_cache_shape: IndexList[(rank + 1)], mask_shape: IndexList[mask_rank], output: LayoutTensor[dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], scale: Float32)
Variant of flash attention that takes the previous KV cache (input_{k,v}_cache_fn) and the current KV tensors (input_k_fn and input_v_fn) as separate arguments.
This works around the fact that fusion currently can't look through a concat.
Instead, this kernel performs an in-place concat fusion: the input lambdas
input_{k,v}_cache_fn_wrapper read previous-sequence KV elements from
the KV cache and current KV elements from the tensors k and v.
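The in-place concat fusion can be illustrated with a small sketch. The code below is a hypothetical pure-Python model (not the Mojo kernel itself): a wrapper, analogous to input_{k,v}_cache_fn_wrapper, routes each position index either to the cached KV entries or to the current-step tensors, so attention never materializes the concatenated K/V buffers.

```python
import math

def make_split_kv_fn(cache, new):
    """Return a lookup that reads position i from the KV cache if it falls
    in the previous sequence, else from the current tensor. This is the
    'in-place concat': no concatenated buffer is ever built."""
    cache_len = len(cache)

    def lookup(i):
        return cache[i] if i < cache_len else new[i - cache_len]

    return lookup

def attention(q, k_fn, v_fn, seq_len, scale):
    """Single-query scaled dot-product attention over seq_len positions,
    reading K and V through lookup functions rather than dense tensors."""
    scores = [scale * sum(qd * kd for qd, kd in zip(q, k_fn(i)))
              for i in range(seq_len)]
    m = max(scores)                       # subtract max for stability
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    out = [0.0] * len(q)
    for i, w in enumerate(exps):
        for d, vd in enumerate(v_fn(i)):
            out[d] += (w / denom) * vd
    return out

# Previous-sequence KV entries (the cache) and current-step KV entries.
k_cache = [[1.0, 0.0], [0.0, 1.0]]
v_cache = [[1.0, 2.0], [3.0, 4.0]]
k_new = [[1.0, 1.0]]
v_new = [[5.0, 6.0]]

k_fn = make_split_kv_fn(k_cache, k_new)
v_fn = make_split_kv_fn(v_cache, v_new)
q = [1.0, 0.5]
out = attention(q, k_fn, v_fn, seq_len=3, scale=1.0)

# Reference: explicitly concatenate first, then attend. Both paths agree.
out_ref = attention(q,
                    lambda i: (k_cache + k_new)[i],
                    lambda i: (v_cache + v_new)[i],
                    seq_len=3, scale=1.0)
assert all(abs(a - b) < 1e-12 for a, b in zip(out, out_ref))
```

The real kernel applies the same idea per SIMD lane and per head, with the boundary between cache and current tensors determined by the cached sequence length.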