
Mojo function

mha_decoding_single_batch

mha_decoding_single_batch[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, BM: Scalar[DType.uindex], BN: Scalar[DType.uindex], BK: Scalar[DType.uindex], WM: Scalar[DType.uindex], WN: Scalar[DType.uindex], depth: Scalar[DType.uindex], num_heads: Scalar[DType.uindex], num_threads: Scalar[DType.uindex], num_pipeline_stages: Scalar[DType.uindex], group: Scalar[DType.uindex] = 1, use_score_mod: Bool = False, decoding_warp_split_k: Bool = False, sink: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], scale: Float32, num_keys: Scalar[DType.uindex], num_partitions: Scalar[DType.uindex], max_cache_valid_length: Scalar[DType.uindex], mask: mask_t, score_mod: score_mod_t, batch_idx: Int, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]])

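Computes multi-head attention for the decoding step of a single batch entry (selected by batch_idx), using the Flash attention v2 (FlashAttention-2) algorithm.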
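The description above only names the algorithm. The following is a rough, non-authoritative Python sketch of the FlashAttention-2 online-softmax recurrence for one decoding query. It walks the keys in blocks of BN and keeps a running maximum and a running sum of exponentials, so the scores never need to be fully materialized.

It also shows how partial results from several key partitions could be merged using each partition's maximum and exponential sum. That num_partitions, qk_max_ptr, and exp_sum_ptr play this split-K role is an assumption inferred from their names; this page does not state it.

The sketch omits masking, score_mod, attention sinks, grouped heads (group), pipelining, and GPU tiling. All function names in it are hypothetical.

```python
import math


def attention_reference(q, keys, values, scale):
    """Plain softmax attention for one query vector (used to check the sketch)."""
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    depth = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(depth)]


def partition_online_softmax(q, keys, values, scale, block_n):
    """FlashAttention-2-style pass over one partition of the keys (hypothetical helper).

    Returns the unnormalized output accumulator, the running max of the
    scaled scores, and the running sum of exp(score - max).
    """
    depth = len(values[0])
    m = -math.inf          # running max of scores seen so far
    l = 0.0                # running sum of exp(score - m)
    acc = [0.0] * depth    # unnormalized weighted sum of value rows
    for start in range(0, len(keys), block_n):
        k_blk = keys[start:start + block_n]
        v_blk = values[start:start + block_n]
        s = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_blk]
        m_new = max(m, max(s))
        # Rescale earlier contributions to the new max; exp(-inf) == 0.0 on the first block.
        correction = math.exp(m - m_new)
        p = [math.exp(x - m_new) for x in s]
        l = l * correction + sum(p)
        acc = [a * correction + sum(pj * v[d] for pj, v in zip(p, v_blk))
               for d, a in enumerate(acc)]
        m = m_new
    return acc, m, l


def merge_partitions(parts):
    """Combine (acc, max, exp_sum) triples from independent key partitions.

    Assumption: this mirrors what per-partition max / exp-sum buffers are used for.
    """
    global_max = max(m for _, m, _ in parts)
    depth = len(parts[0][0])
    out = [0.0] * depth
    total = 0.0
    for acc, m, l in parts:
        w = math.exp(m - global_max)  # bring each partition to a common max
        total += l * w
        for d in range(depth):
            out[d] += acc[d] * w
    return [x / total for x in out]  # final softmax normalization


def split_k_decode(q, keys, values, scale, num_partitions, block_n):
    """Split the keys into up to num_partitions chunks, process each, then merge."""
    n = len(keys)
    chunk = math.ceil(n / num_partitions)
    parts = [
        partition_online_softmax(q, keys[s:s + chunk], values[s:s + chunk], scale, block_n)
        for s in range(0, n, chunk)
    ]
    return merge_partitions(parts)


# Usage: one query of depth 2 against five cached keys/values.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 2.0], [0.5, 0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
out = split_k_decode(q, keys, values, scale=1 / math.sqrt(2), num_partitions=2, block_n=2)
```

Keeping the accumulator unnormalized and dividing once at the end is the FlashAttention-2 choice that avoids a rescale-and-divide per block. The merge step only needs each partition's maximum and exponential sum, which is why such per-partition statistics are worth storing.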
