Mojo function
mla_sm100_prefill_per_token_scale
mla_sm100_prefill_per_token_scale[
    output_dtype: DType,
    q_dtype: DType,
    rope_dtype: DType,
    scale_dtype: DType,
    KType: MHAOperand,
    VType: MHAOperand,
    KRopeType: MHAOperand,
    MaskType: MHAMask,
    MaxPromptLenType: OptionallyStaticInt,
    //,
    config: MHAConfig[config.dtype],
    group: Int,
    q_depth: Int,
    cache_depth: Int,
    _ndbuffer_mha_operand: Bool
](
    output: TileTensor[output_dtype, output.LayoutType, output.origin, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    q_nope: TileTensor[q_dtype, q_nope.LayoutType, q_nope.origin, linear_idx_type=q_nope.linear_idx_type, element_size=q_nope.element_size],
    q_rope: LayoutTensor[rope_dtype, q_rope.layout, q_rope.origin, element_layout=q_rope.element_layout, layout_int_type=q_rope.layout_int_type, linear_idx_type=q_rope.linear_idx_type, masked=q_rope.masked, alignment=q_rope.alignment],
    q_scale: LayoutTensor[scale_dtype, q_scale.layout, q_scale.origin, element_layout=q_scale.element_layout, layout_int_type=q_scale.layout_int_type, linear_idx_type=q_scale.linear_idx_type, masked=q_scale.masked, alignment=q_scale.alignment],
    k_nope: KType,
    k_rope: KRopeType,
    v: VType,
    mask_functor: MaskType,
    valid_length: TileTensor[DType.uint32, valid_length.LayoutType, valid_length.origin, linear_idx_type=valid_length.linear_idx_type, element_size=valid_length.element_size],
    max_prompt_len: MaxPromptLenType,
    scale: Float32,
    batch_size: Int,
    ctx: DeviceContext
)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!