Skip to main content
Log in

Mojo function

flare_mla_decoding_dispatch

flare_mla_decoding_dispatch[
    rank: Int,
    k_t: MHAOperand,
    mask_t: MHAMask,
    score_mod_t: ScoreModTrait,
    type: DType,
    q_shape: DimList, //,
    kv_num_heads: Int,
    use_score_mod: Bool = False,
    config: MHAConfig = MHAConfig(type, UInt(q_shape.get[rank - 2]()), UInt(q_shape.get[rank - 1]()), None, None, None, None, None, UInt(2 if ":90" in _accelerator_arch() else 4), UInt(1), FlashAttentionAlgorithm()),
    ragged: Bool = False,
    _is_cache_length_accurate: Bool = False,
    _use_valid_length: Bool = True,
    decoding_warp_split_k: Bool = False,
](
    output: NDBuffer[type, rank, origin, shape, strides],
    q: NDBuffer[type, rank, origin, q_shape, strides],
    k: k_t,
    mask_functor: mask_t,
    score_mod_functor: score_mod_t,
    valid_length: NDBuffer[uint32, 1, origin, shape, strides],
    max_prompt_len: Int,
    max_cache_valid_length: Int,
    scale: SIMD[float32, 1],
    ctx: DeviceContext,
    kv_input_row_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]] = None,
    num_partitions: OptionalReg[Int] = None,
)

Was this page helpful?