Mojo function
flash_attention_ragged
flash_attention_ragged[
    rank: Int,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    type: DType,
    q_shape: DimList, //,
    use_score_mod: Bool = False,
    config: MHAConfig = MHAConfig(
        type,
        UInt(q_shape.get[rank - 2]()),
        UInt(q_shape.get[rank - 1]()),
        OptionalReg[UInt](None),
        OptionalReg[UInt](None),
        OptionalReg[UInt](None),
        OptionalReg[UInt](None),
        OptionalReg[UInt](None),
        UInt(4),
        UInt(1),
        FlashAttentionAlgorithm(-1),
        OptionalReg[UInt](None),
        TensorMapSwizzle(3),
    ),
    decoding_warp_split_k: Bool = False,
    naive_kernel: Bool = False,
](
    output: NDBuffer[dtype, rank, origin, shape, strides],
    q: NDBuffer[type, rank, origin, q_shape, strides],
    k: NDBuffer[dtype, rank, origin, shape, strides],
    v: NDBuffer[dtype, rank, origin, shape, strides],
    input_row_offsets: ManagedTensorSlice[IOSpec[True, IO(-1)](), static_spec=StaticTensorSpec.create_unknown[DType.uint32, 1]()],
    max_prompt_len: NDBuffer[DType.uint32, 1, origin, shape, strides],
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    scale: Float32,
    ctx: DeviceContext,
    num_partitions: OptionalReg[Int] = OptionalReg[Int](None),
)
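Based on the signature, this kernel computes multi-head flash attention over a ragged (variable-length) batch: the query, key, and value tokens of all sequences are packed back to back into single buffers, input_row_offsets (a uint32 tensor) delimits where each sequence starts and ends, and max_prompt_len carries the longest prompt length in the batch. mask_functor supplies the attention mask, score_mod_functor an optional score modification (enabled with use_score_mod), scale is the softmax scaling factor applied to the attention scores, and the kernel is launched on the given DeviceContext; num_partitions and decoding_warp_split_k appear to control split-k partitioning for decoding, while naive_kernel presumably falls back to a non-flash reference path.

The sketch below only illustrates the ragged layout convention assumed above: row offsets are cumulative token counts, so sequence i occupies rows [offsets[i], offsets[i + 1]) of the packed Q/K/V buffers. The helper build_row_offsets is hypothetical and not part of this library.

    # Hypothetical helper: build cumulative row offsets for a ragged batch.
    # Assumes the convention offsets[0] == 0 and offsets[-1] == total tokens,
    # so sequence i spans rows [offsets[i], offsets[i + 1]).
    fn build_row_offsets(seq_lens: List[Int]) -> List[UInt32]:
        var offsets = List[UInt32](capacity=len(seq_lens) + 1)
        offsets.append(0)
        var total: UInt32 = 0
        for i in range(len(seq_lens)):
            total += UInt32(seq_lens[i])
            offsets.append(total)
        return offsets

    fn main():
        # Three prompts of 5, 2, and 9 tokens packed back to back.
        var offsets = build_row_offsets(List[Int](5, 2, 9))
        for i in range(len(offsets)):
            print(offsets[i])  # prints 0, 5, 7, 16

An array built this way is what one would hand to the kernel as input_row_offsets (after copying it into the expected uint32 tensor type on the target device); the packed Q buffer would then have offsets[-1] rows in its sequence dimension.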