Python class

ModelInputs

class max.pipelines.ModelInputs(*, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)

Bases: object

Base class for model inputs.

Subclass this class as a dataclass to encapsulate the inputs for your model. A subclass may declare any number of fields.

The following example demonstrates how to create a custom inputs class:

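from dataclasses import dataclass

# The import paths for Buffer and DType are assumed; adjust them to your MAX version.
from max.driver import Buffer
from max.dtype import DType
from max.pipelines import ModelInputs

@dataclass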
class ReplitInputs(ModelInputs):
    tokens: Buffer
    input_row_offsets: Buffer

tokens = Buffer.zeros((1, 2, 3), DType.int64)
input_row_offsets = Buffer.zeros((1, 1, 1), DType.int64)

# Initialize inputs
inputs = ReplitInputs(tokens=tokens, input_row_offsets=input_row_offsets)

# Access tensors
list(inputs) == [tokens, input_row_offsets]  # Output: True

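Parameters:

kv_cache_inputs (KVCacheInputs[Buffer, Buffer] | None)
lora_ids (Buffer | None)
lora_ranks (Buffer | None)
hidden_states (Buffer | list[Buffer] | None)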

buffers

property buffers: tuple[Buffer, ...]

Returns positional Buffer inputs for model ABI calls.

hidden_states

hidden_states: Buffer | list[Buffer] | None = None

Hidden states for a variable number of tokens per sequence.

For data parallel models, this can be a list of Buffers where each Buffer contains hidden states for the sequences assigned to that device.
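Because this field accepts a single Buffer, a list of Buffers, or None, consuming code may want to normalize it before use. The following is a minimal sketch of that idea, not part of the MAX API: FakeBuffer is a hypothetical stand-in for Buffer so the snippet runs without MAX installed, and per_device_hidden_states is a hypothetical helper name.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class FakeBuffer:
    """Hypothetical stand-in for Buffer; it only carries a label."""

    name: str


def per_device_hidden_states(
    hidden_states: FakeBuffer | list[FakeBuffer] | None,
) -> list[FakeBuffer]:
    """Normalize the three accepted forms into a list, one entry per device."""
    if hidden_states is None:
        return []
    if isinstance(hidden_states, list):
        # Data parallel case: one buffer per device; copy the list.
        return list(hidden_states)
    # Single-device case: wrap the lone buffer.
    return [hidden_states]


single = per_device_hidden_states(FakeBuffer("gpu0"))
sharded = per_device_hidden_states([FakeBuffer("gpu0"), FakeBuffer("gpu1")])
```

With this shape, downstream code can iterate over the result without branching on the field's type.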

kv_cache_inputs

kv_cache_inputs: KVCacheInputs[Buffer, Buffer] | None = None

lora_ids

lora_ids: Buffer | None = None

Buffer containing the LoRA IDs.

lora_ranks

lora_ranks: Buffer | None = None

Buffer containing the LoRA ranks.

update()

update(**kwargs)

Updates attributes from keyword arguments. A keyword is applied only if it names an existing attribute and its value is not None; other keywords are ignored.

Return type:

None
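The filtering rule described for update() can be illustrated with a small stand-in class. This is a sketch under assumptions rather than the MAX implementation: SketchInputs is a hypothetical class, plain lists stand in for Buffers, and the method body reflects one reading of "only existing, non-None".

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class SketchInputs:
    """Hypothetical stand-in for a ModelInputs subclass."""

    lora_ids: Optional[Any] = None
    lora_ranks: Optional[Any] = None

    def update(self, **kwargs: Any) -> None:
        # Assumed semantics: ignore unknown names and ignore None values.
        for key, value in kwargs.items():
            if hasattr(self, key) and value is not None:
                setattr(self, key, value)


inputs = SketchInputs(lora_ids=[1, 2])

# lora_ids=None is skipped, lora_ranks is set, unknown_field is ignored.
inputs.update(lora_ids=None, lora_ranks=[8, 8], unknown_field=123)
```

Under this reading, passing None never clears a field, so a caller that needs to reset an attribute would have to assign it directly.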