Mojo function
flash_attention_split_kv
```mojo
flash_attention_split_kv[
    type: DType, rank: Int, mask_rank: Int, //,
    input_k_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_v_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_k_cache_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_v_cache_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0],
    input_mask_fn: fn[Int, Int](Index[$1]) capturing -> SIMD[type, $0]
](
    q: NDBuffer[type, rank, origin, shape, strides],
    k_shape: Index[rank],
    v_shape: Index[rank],
    k_cache_shape: Index[(rank + 1)],
    v_cache_shape: Index[(rank + 1)],
    mask_shape: Index[mask_rank],
    output: NDBuffer[type, rank, origin, shape, strides],
    scale: SIMD[float32, 1]
)
```
Variant of flash attention that takes the previous KV cache (`input_{k,v}_cache_fn`) and the current KV tensors (`input_k_fn` and `input_v_fn`) as separate arguments. This works around the fact that fusion can't currently look through a concat. Instead, the kernel performs an in-place concat fusion by wrapping the input lambdas (`input_{k,v}_cache_fn_wrapper`) so that previous-sequence KV elements are read from the KV cache and current KV elements are read from the tensors `k` and `v`.
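
To make the routing concrete, here is a minimal Python sketch (not the Mojo kernel itself) of the index mapping behind this concat fusion: sequence positions below the cached length read from the KV cache, and the remaining positions read from the current K tensor. The names `make_k_load_fn` and `prev_seq_len` are illustrative assumptions, not part of this API.

```python
import numpy as np

def make_k_load_fn(k_cache: np.ndarray, k_new: np.ndarray, prev_seq_len: int):
    """Sketch of a wrapped K load function.

    k_cache: [batch, prev_seq_len, heads, depth] -- previous-sequence keys.
    k_new:   [batch, new_seq_len, heads, depth]  -- keys for the current step.
    """
    def load_k(b: int, s: int, h: int, d: int) -> float:
        if s < prev_seq_len:
            # Previous-sequence element: read from the KV cache.
            return float(k_cache[b, s, h, d])
        # Current element: read from the new K tensor, shifted by the cache length.
        return float(k_new[b, s - prev_seq_len, h, d])
    return load_k

# Usage: the attention loop calls load_k as if K were the concatenation of the
# cache and k_new along the sequence dimension, without materializing it.
k_cache = np.zeros((1, 4, 2, 8), dtype=np.float32)
k_new = np.ones((1, 2, 2, 8), dtype=np.float32)
load_k = make_k_load_fn(k_cache, k_new, prev_seq_len=4)
assert load_k(0, 2, 1, 3) == 0.0  # position 2 comes from the cache
assert load_k(0, 5, 1, 3) == 1.0  # position 5 is element 1 of the current K tensor
```

The V side follows the same pattern, which is how the kernel avoids allocating and writing a concatenated KV tensor before running attention.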