Mojo function
flash_attention_dispatch
flash_attention_dispatch[
    k_t: MHAOperand,
    v_t: MHAOperand,
    mask_t: MHAMask,
    dtype: DType,
    q_layout: Layout,
    //,
    kv_num_heads: Int,
    config: MHAConfig[dtype] = MHAConfig(SIMD(Int[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), SIMD(Int[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), Optional(None), Optional(None), Optional(None), Optional(None), Optional(None), SIMD(4), SIMD(1), FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B),
    ragged: Bool = False,
    sink: Bool = False,
    _is_flash_attention_applicable: Bool = True,
    _is_cache_length_accurate: Bool = False,
    _use_valid_length: Bool = True,
    _padded_ndbuffer: Bool = False,
    decoding_warp_split_k: Bool = False
](
    output: LayoutTensor[output.dtype, output.layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment],
    q: LayoutTensor[dtype, q_layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment],
    k: k_t,
    v: v_t,
    mask_functor: mask_t,
    max_prompt_len: Int,
    max_cache_valid_length: Int,
    scale: Float32,
    is_token_generation: Bool,
    ctx: DeviceContext,
    valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None,
    kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None,
    num_partitions: Optional[Int] = None,
    sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None,
    decode_dispatch_metadata: OptionalReg[MHADecodeDispatchMetadata] = None
)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!