Skip to main content

Mojo function

mha_sm100_dispatch

mha_sm100_dispatch[
    q_type: DType,
    KVType: MHAOperand,
    MaskType: MHAMask,
    output_type: DType,
    MaxPromptLenType: OptionallyStaticInt,
    PartitionType: MHAPartitionScheme,
    //,
    config: MHAConfig[config.dtype],
    group: Int,
    ragged: Bool,
    sink: Bool,
    _is_cache_length_accurate: Bool
](
    output: DeviceBuffer[output_type],
    q_arg: DeviceBuffer[q_type],
    k: KVType,
    v: KVType,
    num_rows_q: Int,
    mask: MaskType,
    valid_length: DeviceBuffer[DType.uint32],
    max_prompt_len_arg: MaxPromptLenType,
    max_cache_valid_length_arg: Int,
    scale: Float32,
    kv_input_row_offsets: OptionalReg[TileTensor[
        DType.uint32,
        Layout[
            #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).shape), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).shape[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64])),
            #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).stride), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).stride[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64]))
        ],
        ImmutAnyOrigin
    ]],
    batch_size_arg: Int,
    partition: PartitionType,
    ctx: DeviceContext,
    sink_weights: OptionalReg[TileTensor[
        q_type,
        Layout[
            #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).shape), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).shape[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64])),
            #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).stride), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).stride[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64]))
        ],
        ImmutAnyOrigin
    ]]
)

Was this page helpful?