Mojo function
flash_attention_ragged
```mojo
flash_attention_ragged[
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    type: DType,
    q_layout: Layout, //,
    use_score_mod: Bool = False,
    config: MHAConfig[type] = MHAConfig(
        SIMD(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])),
        SIMD(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])),
        Optional(None),
        Optional(None),
        Optional(None),
        Optional(None),
        Optional(None),
        4,
        1,
        FlashAttentionAlgorithm(-1),
        TensorMapSwizzle.SWIZZLE_128B,
    ),
    decoding_warp_split_k: Bool = False,
    naive_kernel: Bool = False,
](
    output: LayoutTensor[output.dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment],
    q: LayoutTensor[type, q_layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment],
    k: LayoutTensor[k.dtype, k.layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment],
    v: LayoutTensor[v.dtype, v.layout, v.origin, element_layout=v.element_layout, layout_int_type=v.layout_int_type, linear_idx_type=v.linear_idx_type, masked=v.masked, alignment=v.alignment],
    input_row_offsets: LayoutTensor[DType.uint32, Layout.row_major(-1), ImmutAnyOrigin],
    max_prompt_len: LayoutTensor[DType.uint32, max_prompt_len.layout, max_prompt_len.origin, element_layout=max_prompt_len.element_layout, layout_int_type=max_prompt_len.layout_int_type, linear_idx_type=max_prompt_len.linear_idx_type, masked=max_prompt_len.masked, alignment=max_prompt_len.alignment],
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    scale: Float32,
    ctx: DeviceContext,
    num_partitions: Optional[Int] = None,
)
```
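Judging from the signature, this kernel computes multi-head flash attention over a ragged (variable-length) batch: `q`, `k`, and `v` are packed along the token dimension with no padding, `input_row_offsets` (a `uint32` vector, CSR-style, typically `batch_size + 1` entries) marks where each sequence's rows begin and end, and `max_prompt_len` carries the longest sequence length. `mask_functor` and `score_mod_functor` supply the attention mask and the optional score modification (applied only when `use_score_mod` is true), `scale` is the softmax scaling applied to the query-key scores, and `num_partitions` optionally controls split-k partitioning. The parameters before `//` (`mask_t`, `score_mod_t`, `type`, `q_layout`) are infer-only, so callers never spell them out.

A minimal call sketch follows. The import paths, the `CausalMask` / `IdentityScoreMod` functors, the helper name `ragged_attention_example`, and the `scale` value are all assumptions for illustration (none appear on this page); tensor allocation and data movement are omitted.

```mojo
# Minimal call sketch, NOT from this page: import paths and the mask /
# score-mod functor names below are assumptions that may differ by version.
from gpu.host import DeviceContext
from layout import LayoutTensor
from nn.mha import flash_attention_ragged      # assumed module path
from nn.mha_mask import CausalMask             # assumed MHAMask implementation
from nn.mha_score_mod import IdentityScoreMod  # assumed ScoreModTrait implementation

fn ragged_attention_example(
    output: LayoutTensor,             # packed [total_tokens, num_heads, depth]
    q: LayoutTensor,                  # packed queries, same layout as output
    k: LayoutTensor,                  # packed keys
    v: LayoutTensor,                  # packed values
    input_row_offsets: LayoutTensor,  # uint32, row offset of each sequence
    max_prompt_len: LayoutTensor,     # uint32, longest sequence in the batch
    ctx: DeviceContext,
) raises:
    # mask_t, score_mod_t, type, and q_layout precede `//` in the parameter
    # list, so they are inferred from the arguments; only the keyword
    # parameters after `//` (use_score_mod, config, ...) can be set here.
    flash_attention_ragged[use_score_mod=False](
        output,
        q,
        k,
        v,
        input_row_offsets,
        max_prompt_len,
        CausalMask(),
        IdentityScoreMod(),
        scale=0.125,  # assumes depth == 64; usually 1 / sqrt(depth)
        ctx=ctx,
        num_partitions=None,  # let the kernel choose its own split-k
    )
```

Packing variable-length sequences this way avoids padding every prompt to the batch maximum, which is the usual motivation for a ragged attention entry point.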