
Mojo function

mha_decoding_single_batch_pipelined

mha_decoding_single_batch_pipelined[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, BM: UInt, BN: UInt, BK: UInt, WM: UInt, WN: UInt, depth: UInt, num_heads: UInt, num_threads: UInt, num_pipeline_stages: UInt, group: UInt = UInt(1), use_score_mod: Bool = False, decoding_warp_split_k: Bool = False](q_ptr: UnsafePointer[SIMD[q_type, 1]], k: k_t, v: v_t, output_ptr: UnsafePointer[SIMD[output_type, 1]], exp_sum_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]], qk_max_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]], scale: SIMD[float32, 1], num_keys: UInt, num_partitions: UInt, max_cache_valid_length: UInt, mask: mask_t, score_mod: score_mod_t, batch_idx: Int)

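Multi-head attention decoding for a single batch entry (`batch_idx`), implemented with the FlashAttention-2 algorithm and pipelined across `num_pipeline_stages` stages.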
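This page gives only a one-line description, so the following is a plain-Python reference sketch, not the Mojo kernel. It shows the FlashAttention-2-style online softmax for one decoding query whose keys are split into partitions, followed by a merge of the partial results.

The merge step assumes that `exp_sum_ptr` and `qk_max_ptr` hold per-partition softmax statistics (running sum of exponentials and running max of scaled scores) that are combined when `num_partitions > 1`. That reading is inferred from the parameter names and is not confirmed by this page.

Masking (`mask`) and score modification (`score_mod`) are omitted. The helper names `partition_attention`, `combine_partitions`, and `reference_attention` are hypothetical.

```python
import math
from typing import List, Tuple

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def partition_attention(q: Vector, keys: List[Vector], values: List[Vector],
                        scale: float, block: int) -> Tuple[Vector, float, float]:
    """Online softmax over one key partition, processed in key blocks.

    Hypothetical helper: returns the unnormalized output accumulator, the
    running max of scaled scores (qk_max) and the running exp sum (exp_sum).
    """
    acc = [0.0] * len(values[0])
    qk_max = -math.inf
    exp_sum = 0.0
    for start in range(0, len(keys), block):
        # Scaled scores for one block of keys (a tile of BN keys in the kernel).
        scores = [scale * dot(q, k) for k in keys[start:start + block]]
        new_max = max(qk_max, max(scores))
        # Rescale what was accumulated so far to the new running max.
        correction = math.exp(qk_max - new_max)  # exp(-inf) == 0.0 on the first block
        acc = [a * correction for a in acc]
        exp_sum *= correction
        for s, v in zip(scores, values[start:start + block]):
            p = math.exp(s - new_max)
            exp_sum += p
            acc = [a + p * vi for a, vi in zip(acc, v)]
        qk_max = new_max
    return acc, qk_max, exp_sum


def combine_partitions(parts: List[Tuple[Vector, float, float]]) -> Vector:
    """Merge partial results using each partition's qk_max and exp_sum (assumed layout)."""
    global_max = max(m for _, m, _ in parts)
    out = [0.0] * len(parts[0][0])
    total = 0.0
    for acc, m, s in parts:
        w = math.exp(m - global_max)  # re-base this partition onto the global max
        total += s * w
        out = [o + a * w for o, a in zip(out, acc)]
    return [o / total for o in out]


def reference_attention(q: Vector, keys: List[Vector], values: List[Vector],
                        scale: float) -> Vector:
    """Plain softmax(scale * q.K^T) @ V for a single query, for comparison."""
    scores = [scale * dot(q, k) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    z = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / z
            for d in range(len(values[0]))]


# Usage: split 5 keys into two partitions and compare with the reference.
q = [0.5, -1.0]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.2, 0.3]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
scale = 1.0 / math.sqrt(len(q))
parts = [partition_attention(q, keys[:3], values[:3], scale, block=2),
         partition_attention(q, keys[3:], values[3:], scale, block=2)]
out = combine_partitions(parts)
ref = reference_attention(q, keys, values, scale)
```

Because every partial result carries its own max and exp sum, partitions can be computed independently and merged afterwards without ever materializing the full score vector. This is the property that makes splitting the key range across `num_partitions` worthwhile for long caches.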
