Mojo function
flash_attention
flash_attention[
    type: DType,
    rank: Int,
    mask_rank: Int,
    //,
    input_k_fn: fn[Int, Int](IndexList[$1]) capturing -> SIMD[type, $0],
    input_v_fn: fn[Int, Int](IndexList[$1]) capturing -> SIMD[type, $0],
    input_mask_fn: fn[Int, Int](IndexList[$1]) capturing -> SIMD[type, $0]
](
    q: NDBuffer[type, rank, origin, shape, strides],
    k_shape: IndexList[rank],
    v_shape: IndexList[rank],
    mask_shape: IndexList[mask_rank],
    output: NDBuffer[type, rank, origin, shape, strides],
    scale: SIMD[float32, 1]
)
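The signature alone does not say what is computed. The sketch below is a reference model in Python, not Mojo, and it rests on an assumption this page does not confirm: that the function evaluates standard scaled dot-product attention, softmax(scale * Q K^T + mask) V, with an additive mask. The callbacks `k_fn`, `v_fn`, and `mask_fn` are hypothetical stand-ins for `input_k_fn`, `input_v_fn`, and `input_mask_fn`. They return whole rows or scalars rather than SIMD vectors, and the sketch handles a single head only. It uses a streaming (online) softmax, the usual idea behind flash attention, so the full score matrix is never stored.

```python
import math


def attention_reference(q, k_fn, v_fn, mask_fn, num_keys, depth_v, scale):
    """Single-head attention with a streaming (online) softmax.

    q          : list of query rows (lists of floats)
    k_fn(j)    : returns the j-th key row     (hypothetical callback)
    v_fn(j)    : returns the j-th value row   (hypothetical callback)
    mask_fn(i, j): additive mask for query i, key j (assumed additive)
    """
    out = []
    for i, q_row in enumerate(q):
        running_max = -math.inf   # largest score seen so far
        denom = 0.0               # softmax denominator, rescaled as we go
        acc = [0.0] * depth_v     # unnormalized weighted sum of value rows
        for j in range(num_keys):
            k_row = k_fn(j)
            score = scale * sum(a * b for a, b in zip(q_row, k_row))
            score += mask_fn(i, j)
            new_max = max(running_max, score)
            if new_max == -math.inf:
                continue  # every key so far is fully masked
            # Rescale earlier partial sums to the new maximum.
            correction = math.exp(running_max - new_max)
            weight = math.exp(score - new_max)
            denom = denom * correction + weight
            v_row = v_fn(j)
            acc = [a * correction + weight * v for a, v in zip(acc, v_row)]
            running_max = new_max
        # A fully masked query row yields zeros instead of dividing by zero.
        out.append([a / denom for a in acc] if denom > 0.0 else acc)
    return out


# Usage: two queries, two keys, no masking.
keys = [[1.0, 0.0], [0.0, 1.0]]
values = [[1.0, 2.0], [3.0, 4.0]]
result = attention_reference(
    q=[[1.0, 0.0], [0.0, 1.0]],
    k_fn=lambda j: keys[j],
    v_fn=lambda j: values[j],
    mask_fn=lambda i, j: 0.0,
    num_keys=2,
    depth_v=2,
    scale=1.0,
)
```

Passing the key, value, and mask data through callbacks, as the Mojo signature does, lets the caller decide how those tensors are laid out or generated; only `q` and `output` are concrete buffers.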