Skip to main content

Mojo function

mha_sm100_depth512_dispatch

mha_sm100_depth512_dispatch[
    q_type: DType,
    KVType: MHAOperand,
    MaskType: MHAMask,
    output_type: DType,
    MaxPromptLenType: OptionallyStaticInt,
    PartitionType: MHAPartitionScheme,
    //,
    config: MHAConfig[config.dtype],
    group: Int,
    ragged: Bool,
    _is_cache_length_accurate: Bool
](
    output: DeviceBuffer[output_type],
    q_arg: UnsafePointer[Scalar[q_type], q_arg.origin],
    k: KVType,
    v: KVType,
    num_rows_q: Int,
    mask: MaskType,
    valid_length: UnsafePointer[UInt32, valid_length.origin],
    max_prompt_len_arg: MaxPromptLenType,
    max_cache_valid_length_arg: Int,
    scale: Float32,
    kv_input_row_offsets: OptionalReg[TileTensor[DType.uint32, Layout[
        #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](product_each(Layout.row_major(-1).shape)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[CoordLike], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(Layout.row_major(-1).shape)[idx])] if (Int[IntTuple](product_each(Layout.row_major(-1).shape)[idx]) != -1) else RuntimeInt[DType.int64])),
        #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](product_each(Layout.row_major(-1).stride)), [idx: __mlir_type.index] RuntimeInt[DType.int64]), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[CoordLike], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[Int[IntTuple](product_each(Layout.row_major(-1).stride)[idx])] if (Int[IntTuple](product_each(Layout.row_major(-1).stride)[idx]) != -1) else RuntimeInt[DType.int64]))
    ], ImmutAnyOrigin]],
    batch_size_arg: Int,
    partition: PartitionType,
    ctx: DeviceContext
)

Was this page helpful?