
Mojo function

mla_prefill_single_batch

mla_prefill_single_batch[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, k_rope_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, config: MHAConfig, group: Int = 1, q_depth: Int = 192, cache_depth: Int = 576, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False](q_ptr: UnsafePointer[SIMD[q_type, 1]], k: k_t, v: v_t, k_rope: k_rope_t, output_ptr: UnsafePointer[SIMD[output_type, 1]], softmax_info_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]], prev_output_ptr: UnsafePointer[SIMD[output_type, 1]], prev_softmax_info_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]], scale: SIMD[float32, 1], seq_len: Int, max_seq_len: Int, start_pos: SIMD[uint32, 1], cache_start_pos: SIMD[uint32, 1], num_keys: Int, mask: mask_t, score_mod: score_mod_t, batch_idx: Int)

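Computes multi-head latent attention (MLA) for a single batch entry during encoding (prefill), that is, when the sequence length is greater than 1.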
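The reference above does not describe the kernel's internals. As a rough mental model only, the pure-Python sketch below computes causal scaled dot-product attention for one sequence with several query tokens, which is the prefill case this function targets. The helper name `causal_prefill_attention`, the plain-list tensors, and the convention that query `i` sits at absolute position `start_pos + i` are assumptions for illustration. The sketch ignores the MLA-specific split into `k`, `v`, and `k_rope`, custom masks (`mask_t`), score modification, and cascade attention.

```python
import math

def causal_prefill_attention(q, k, v, scale, start_pos=0):
    """Hypothetical reference: causal attention for one sequence.

    q: seq_len query vectors; k, v: num_keys key/value vectors.
    Assumption: query i is at absolute position start_pos + i and may
    attend to keys 0 .. start_pos + i (a plain causal mask).
    """
    out = []
    for i, qi in enumerate(q):
        # Index of the last key this query is allowed to see.
        last = min(start_pos + i, len(k) - 1)
        # Scaled dot-product score against every visible key.
        scores = [scale * sum(a * b for a, b in zip(qi, k[j]))
                  for j in range(last + 1)]
        # Subtract the max before exponentiating for numerical stability.
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        # Softmax-weighted sum of the visible value vectors.
        out.append([sum(w * v[j][d] for j, w in enumerate(weights)) / total
                    for d in range(len(v[0]))])
    return out
```

For example, with two query tokens and `start_pos=0`, the first token can only attend to the first key, so its output equals the first value vector; the second token mixes both value vectors.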
