Mojo function
mla_prefill_single_batch
mla_prefill_single_batch[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, k_rope_t: MHAOperand, output_type: DType, mask_t: MHAMask, *, config: MHAConfig[config.dtype], group: Int = 1, q_depth: Int = 192, cache_depth: Int = 576](q_ptr: UnsafePointer[Scalar[q_type], q_ptr.origin], k: k_t, v: v_t, k_rope: k_rope_t, output_ptr: UnsafePointer[Scalar[output_type], output_ptr.origin], scale: Float32, seq_len: Int, max_seq_len: Int, start_pos: UInt32, cache_start_pos: UInt32, num_keys: Int, mask: mask_t, batch_idx: Int)
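Computes Multi-head Latent Attention (MLA) for the prefill (context-encoding) phase, where the query sequence length seq_len is greater than 1. It processes the single batch entry selected by batch_idx.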
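The attention this function is expected to compute can be written per query head as shown below. This is a reading aid under stated assumptions, not derived from the kernel source:

- Each q_depth = 192 query row is assumed to split into a 128-wide non-positional part matched against k and a 64-wide rotary part matched against k_rope. This follows the DeepSeek-V2 layout, in which the cache_depth = 576 cache row holds a 512-wide latent plus 64 rotary dimensions.
- scale is assumed to multiply the logits before masking.
- The mask is assumed to contribute m_{ij} = 0 for keys visible to query i and -\infty otherwise.

The symbols q^{nope}, q^{rope}, and m_{ij} are illustrative names, not parameters of the function.

```latex
\begin{aligned}
s_{ij} &= \mathrm{scale}\,\bigl(q^{\mathrm{nope}}_i \cdot k_j + q^{\mathrm{rope}}_i \cdot k^{\mathrm{rope}}_j\bigr) + m_{ij},
  \qquad 0 \le i < \mathrm{seq\_len},\; 0 \le j < \mathrm{num\_keys} \\
p_{ij} &= \frac{\exp(s_{ij})}{\sum_{j'=0}^{\mathrm{num\_keys}-1} \exp(s_{ij'})} \\
o_i &= \sum_{j=0}^{\mathrm{num\_keys}-1} p_{ij}\, v_j
\end{aligned}
```

Under these assumptions, each o_i is the row written to output_ptr for query position i of the selected batch entry.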