Mojo function
flare_mla_prefill_dispatch
flare_mla_prefill_dispatch[
    k_t: MHAOperand,
    v_t: MHAOperand,
    k_rope_t: MHAOperand,
    mask_t: MHAMask,
    dtype: DType,
    output_type: DType,
    //,
    kv_num_heads: Int,
    config: MHAConfig[dtype],
    q_depth: Int = 192,
    cache_depth: Int = 576,
    _ndbuffer_mha_operand: Bool = False
](
    output: TileTensor[output_type, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    q: TileTensor[dtype, q.LayoutType, q.origin, linear_idx_type=q.linear_idx_type, element_size=q.element_size],
    k: k_t,
    v: v_t,
    k_rope: k_rope_t,
    mask_functor: mask_t,
    valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size],
    max_prompt_len: Int,
    scale: Float32,
    ctx: DeviceContext,
    cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(VariadicList(-1)), MutAnyOrigin]] = None
)