Skip to main content

Mojo function

mha_sm100_dispatch

mha_sm100_dispatch[
    q_type: DType,
    KVType: MHAOperand,
    MaskType: MHAMask,
    ScoreModType: ScoreModTrait,
    output_type: DType,
    MaxPromptLenType: OptionallyStaticInt,
    PartitionType: MHAPartitionScheme,
    //,
    config: MHAConfig,
    group: Int,
    use_score_mod: Bool,
    ragged: Bool,
    sink: Bool,
    _is_cache_length_accurate: Bool
](
    output: UnsafePointer[SIMD[output_type, 1]],
    q_arg: UnsafePointer[SIMD[q_type, 1]],
    k: KVType,
    v: KVType,
    num_rows_q: Int,
    mask_functor: MaskType,
    score_mod_functor: ScoreModType,
    valid_length: ManagedTensorSlice[io_spec, static_spec=static_spec],
    max_prompt_len_arg: MaxPromptLenType,
    max_cache_valid_length_arg: Int,
    scale: SIMD[float32, 1],
    kv_input_row_offsets: OptionalReg[NDBuffer[uint32, 1, MutableAnyOrigin]],
    batch_size_arg: Int,
    partition: PartitionType,
    ctx: DeviceContext,
    sink_weights: OptionalReg[NDBuffer[q_type, 1, MutableAnyOrigin]]
)

Was this page helpful?