Skip to main content
Log in

Mojo function

flare_mla_prefill_dispatch

```mojo
flare_mla_prefill_dispatch[
    rank: Int,
    k_t: MHAOperand,
    v_t: MHAOperand,
    k_rope_t: MHAOperand,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    type: DType,
    output_type: DType,
    softmax_type: DType,
    q_shape: DimList,
    //,
    kv_num_heads: Int,
    use_score_mod: Bool = False,
    write_softmax_info: Bool = False,
    use_cascade_attention: Bool = False,
    q_depth: Int = 192,
    cache_depth: Int = 576,
    config: MHAConfig = MHAConfig(
        type,
        UInt(q_shape.get[::Int]()),
        UInt(q_shape.get[::Int]()),
        OptionalReg[UInt]({:i1 0, 1}),
        OptionalReg[UInt]({:i1 0, 1}),
        OptionalReg[UInt]({:i1 0, 1}),
        OptionalReg[UInt]({:i1 0, 1}),
        OptionalReg[UInt]({:i1 0, 1}),
        UInt(2 if _accelerator_arch().__contains__[::Bool,::Origin[$2]](__init__[__mlir_type.!kgen.string](":90")) else 4),
        UInt(1),
        FlashAttentionAlgorithm()
    ),
    _ndbuffer_mha_operand: Bool = False
](
    output: NDBuffer[output_type, rank, origin, shape, strides],
    q: NDBuffer[type, rank, origin, q_shape, strides],
    k: k_t,
    v: v_t,
    k_rope: k_rope_t,
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    valid_length: NDBuffer[uint32, 1, origin, shape, strides],
    max_prompt_len: Int,
    scale: SIMD[float32, 1],
    ctx: DeviceContext,
    softmax_info: OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]] = OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]]({:i1 0, 1}),
    cache_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]] = OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]]({:i1 0, 1}),
    prev_output: OptionalReg[NDBuffer[output_type, rank, MutableAnyOrigin]] = OptionalReg[NDBuffer[output_type, rank, MutableAnyOrigin]]({:i1 0, 1}),
    prev_softmax_info: OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]] = OptionalReg[NDBuffer[softmax_type, 3, MutableAnyOrigin]]({:i1 0, 1})
)
```

Was this page helpful?