Mojo function
mla_decoding_single_batch
mla_decoding_single_batch[
    q_type: DType,
    k_t: MHAOperand,
    output_type: DType,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    *,
    BM: UInt,
    BN: UInt,
    BK: UInt,
    WM: UInt,
    WN: UInt,
    depth: UInt,
    depth_v: UInt,
    num_heads: UInt,
    num_threads: UInt,
    num_pipeline_stages: UInt,
    group: UInt = 1,
    use_score_mod: Bool = False,
    decoding_warp_split_k: Bool = False,
](
    q_ptr: UnsafePointer[Scalar[q_type], MutAnyOrigin],
    k: k_t,
    output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin],
    exp_sum_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()], MutAnyOrigin],
    qk_max_ptr: UnsafePointer[Scalar[get_accum_type[q_type]()], MutAnyOrigin],
    scale: Float32,
    num_keys: UInt,
    num_partitions: UInt,
    max_cache_valid_length: UInt,
    mask: mask_t,
    score_mod: score_mod_t,
    batch_idx: Int,
)
Computes multi-head latent attention (MLA) decoding for a single batch entry using the Flash Attention v2 algorithm.
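The kernel evaluates softmax(scale · Q Kᵀ) V with the Flash Attention v2 online-softmax recurrence: keys are processed in tiles, and each tile updates a running row maximum and a running exponential sum instead of materializing the full score matrix. As a sketch of that recurrence (standard FA2 math, stated for orientation rather than taken from this kernel's source), with score block $S_j = \text{scale} \cdot Q K_j^\top$ for key tile $j = 1, \dots, T$:

$$
m_j = \max\bigl(m_{j-1},\, \operatorname{rowmax}(S_j)\bigr), \qquad
\ell_j = e^{\,m_{j-1}-m_j}\,\ell_{j-1} + \operatorname{rowsum}\bigl(e^{\,S_j - m_j}\bigr),
$$

$$
O_j = e^{\,m_{j-1}-m_j}\,O_{j-1} + e^{\,S_j - m_j}\,V_j, \qquad
O = O_T \,/\, \ell_T .
$$

When `num_partitions > 1`, each partition processes a disjoint slice of the keys and writes its partial result through `output_ptr` along with its running maximum (`qk_max_ptr`) and exponential sum (`exp_sum_ptr`). Assuming each partial output $O^{(p)}$ is normalized by its own $\ell_p$, a subsequent reduction pass can combine the partitions as

$$
m = \max_p m_p, \qquad
\ell = \sum_p e^{\,m_p - m}\,\ell_p, \qquad
O = \frac{1}{\ell}\sum_p e^{\,m_p - m}\,\ell_p\, O^{(p)} .
$$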