Skip to main content

Mojo function

mla_prefill

```mojo
mla_prefill[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, k_rope_t: MHAOperand, output_type: DType, mask_t: MHAMask, valid_layout: TensorLayout, config: MHAConfig[config.dtype], group: Int = 128, q_depth: Int = 192, cache_depth: Int = 576, _ndbuffer_mha_operand: Bool = False](q_ptr: UnsafePointer[Scalar[q_type], MutAnyOrigin], k: k_t, v: v_t, k_rope: k_rope_t, output_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin], scale: Float32, batch_size: Int, seq_len_arg: Int, valid_length_tt: TileTensor[DType.uint32, valid_layout, MutAnyOrigin], cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), MutAnyOrigin]], mask: mask_t)
```

Was this page helpful?