Python class
ModelInputs
class max.pipelines.ModelInputs(*, kv_cache_inputs=None, lora_ids=None, lora_ranks=None, hidden_states=None)
Bases: object
Base class for model inputs.
Subclass this class to encapsulate your model's inputs. A subclass can declare any number of dataclass fields.
The following example demonstrates how to create a custom inputs class:
from dataclasses import dataclass

from max.driver import Buffer
from max.dtype import DType
from max.pipelines import ModelInputs

@dataclass
class ReplitInputs(ModelInputs):
    tokens: Buffer
    input_row_offsets: Buffer

# Create placeholder tensors
tokens = Buffer.zeros((1, 2, 3), DType.int64)
input_row_offsets = Buffer.zeros((1, 1, 1), DType.int64)

# Initialize inputs
inputs = ReplitInputs(tokens=tokens, input_row_offsets=input_row_offsets)

# Access tensors
list(inputs) == [tokens, input_row_offsets]  # Output: True
Parameters: kv_cache_inputs, lora_ids, lora_ranks, hidden_states (all keyword-only; each defaults to None)
buffers
Returns positional Buffer inputs for model ABI calls.
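As a rough illustration of what "positional inputs" means, the sketch below flattens a dataclass's non-None fields into a tuple in declaration order. It uses a hypothetical stand-in class (`ToyInputs`) and plain Python lists in place of `Buffer`, so it runs without MAX installed; the real `buffers` property, including which fields it returns and in what order, may differ.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class ToyInputs:
    """Hypothetical stand-in for a ModelInputs subclass (not the MAX class)."""

    tokens: Any
    input_row_offsets: Any
    lora_ids: Optional[Any] = None

    @property
    def buffers(self) -> tuple:
        # Collect field values in declaration order, skipping unset (None) ones,
        # so the tuple can be passed positionally to a compiled model.
        return tuple(
            getattr(self, f.name)
            for f in fields(self)
            if getattr(self, f.name) is not None
        )


inputs = ToyInputs(tokens=[1, 2, 3], input_row_offsets=[0, 3])
print(inputs.buffers)  # ([1, 2, 3], [0, 3])
```

Skipping None values here is a design choice for the sketch: optional inputs such as LoRA data simply drop out of the positional argument list when they are not set.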
hidden_states
Hidden states for a variable number of tokens per sequence.
For data-parallel models, this can be a list of Buffers, where each Buffer holds the hidden states for the sequences assigned to that device.
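The sketch below shows one plausible way to build that per-device list: sequences are assigned round-robin to devices, and each device's hidden states are the token rows of its sequences concatenated together, so sequences of different lengths pack without padding. Nested lists stand in for `Buffer`; the helper name `split_hidden_states` and the round-robin assignment are assumptions for illustration, not MAX's actual scheduling policy.

```python
from typing import List

Row = List[float]  # one token's hidden vector


def split_hidden_states(
    per_sequence: List[List[Row]], num_devices: int
) -> List[List[Row]]:
    """Hypothetical helper: concatenate each device's sequences' token rows.

    Sequence i is assigned to device i % num_devices (round-robin, an assumption).
    """
    per_device: List[List[Row]] = [[] for _ in range(num_devices)]
    for i, seq_rows in enumerate(per_sequence):
        # Sequences may have different token counts; rows are simply appended.
        per_device[i % num_devices].extend(seq_rows)
    return per_device


# Three sequences with 2, 1, and 3 tokens; hidden size 2.
seqs = [
    [[0.1, 0.2], [0.3, 0.4]],
    [[0.5, 0.6]],
    [[0.7, 0.8], [0.9, 1.0], [1.1, 1.2]],
]
shards = split_hidden_states(seqs, num_devices=2)
print([len(s) for s in shards])  # [5, 1]
```

Device 0 receives sequences 0 and 2 (2 + 3 token rows) and device 1 receives sequence 1 (1 row).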
kv_cache_inputs
kv_cache_inputs: KVCacheInputs[Buffer, Buffer] | None = None
lora_ids
Buffer containing the LoRA IDs.
lora_ranks
Buffer containing the LoRA ranks.
update()
update(**kwargs)
Updates attributes from keyword arguments. Only keyword arguments that name an existing attribute and whose value is not None are applied.
Return type:
None
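A minimal sketch of that rule, using a plain dataclass stand-in (`ToyInputs`, hypothetical) rather than the MAX class. It reads "only existing, non-None" as: skip keys that are not attributes of the instance, and skip None values. That reading is an assumption; the actual implementation may differ in details, for example whether unknown keys raise an error instead of being ignored.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class ToyInputs:
    """Hypothetical stand-in for a ModelInputs subclass."""

    tokens: Optional[Any] = None
    lora_ids: Optional[Any] = None

    def update(self, **kwargs: Any) -> None:
        for name, value in kwargs.items():
            # Skip unknown attribute names and None values.
            if hasattr(self, name) and value is not None:
                setattr(self, name, value)


inputs = ToyInputs(tokens=[1, 2], lora_ids=[0])
inputs.update(tokens=[3, 4], lora_ids=None, unknown=5)
print(inputs)  # ToyInputs(tokens=[3, 4], lora_ids=[0])
```

Because None values are ignored, passing `lora_ids=None` cannot clear an existing value; under this reading, clearing a field requires assigning the attribute directly.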