Mojo function

mha_decoding_single_batch

mha_decoding_single_batch[
    q_type: DType,
    k_t: MHAOperand,
    v_t: MHAOperand,
    output_type: DType,
    mask_t: MHAMask,
    *,
    BM: Int,
    BN: Int,
    BK: Int,
    WM: Int,
    WN: Int,
    depth: Int,
    num_heads: Int,
    num_threads: Int,
    num_pipeline_stages: Int,
    group: Int = 1,
    decoding_warp_split_k: Bool = False,
    sink: Bool = False,
](
    q_ptr: UnsafePointer[Scalar[q_type], ImmutAnyOrigin],
    k: k_t,
    v: v_t,
    output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin],
    exp_sum_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()], MutAnyOrigin],
    qk_max_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()], MutAnyOrigin],
    scale: Float32,
    num_keys: Int,
    num_partitions: Int,
    mask: mask_t,
    batch_idx: Int,
    sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), ImmutAnyOrigin]],
)

Flash Attention v2 kernel specialized for decoding a single batch entry. It reads the query for `batch_idx` from `q_ptr`, computes scaled (`scale`), masked (`mask`) attention over the `num_keys` keys and values supplied by `k` and `v`, and writes the result to `output_ptr`. Grouped-query attention is supported through the `group` parameter, and optional attention-sink weights can be supplied via `sink_weights` when `sink` is true. When the work is split across `num_partitions` partitions (split-K decoding), each partition additionally writes its softmax statistics, the running exponential sum and the row-wise QK maximum, to `exp_sum_ptr` and `qk_max_ptr` so that a later pass can combine the partial outputs.
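For background, this is the online-softmax recurrence at the core of Flash Attention v2 (the standard published formulation; the tile index $j$ and symbols below are illustrative, not taken from this kernel's source). For each tile of attention scores $s^{(j)} = \text{scale} \cdot q K_{(j)}^{\top}$, the running maximum $m$, exponential sum $\ell$, and unnormalized output $o$ are updated as

$$
\begin{aligned}
m^{(j)} &= \max\!\big(m^{(j-1)},\ \max_i s_i^{(j)}\big) \\
\ell^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,\ell^{(j-1)} + \sum_i e^{\,s_i^{(j)} - m^{(j)}} \\
o^{(j)} &= e^{\,m^{(j-1)} - m^{(j)}}\,o^{(j-1)} + \sum_i e^{\,s_i^{(j)} - m^{(j)}}\, v_i^{(j)}
\end{aligned}
$$

and the final output is $o / \ell$ after the last tile. As a sketch of how split-K decoding can then be resolved, assuming each of the `num_partitions` partitions writes a normalized partial output $o_p$ alongside its statistics $m_p$ (via `qk_max_ptr`) and $\ell_p$ (via `exp_sum_ptr`), the partials combine as

$$
m = \max_p m_p, \qquad
o = \frac{\sum_p e^{\,m_p - m}\,\ell_p\,o_p}{\sum_p e^{\,m_p - m}\,\ell_p}.
$$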
