
Python class

TextGenerationInputs


class max.interfaces.TextGenerationInputs(batches, num_steps, input_tokens=-1, batch_type=BatchType.TG)


Bases: PipelineInputs, Generic[TextGenerationContextType]

Input parameters for text generation pipeline operations.

This class bundles the batches of contexts and the number of steps required for token generation into a single input object, replacing the previous pattern of passing batch and num_steps as separate parameters.
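To make the bundling concrete, here is a minimal sketch contrasting the two calling conventions. It does not import max; SketchInputs, FakeContext, execute_old, and execute are hypothetical stand-ins that mirror only the documented fields (batches, num_steps, input_tokens), not the library's actual implementation.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar


@dataclass
class FakeContext:
    # Hypothetical placeholder for a text generation context.
    request_id: str


ContextT = TypeVar("ContextT")


# Stand-in mirroring the documented fields of TextGenerationInputs.
# This is NOT the max.interfaces implementation.
@dataclass
class SketchInputs(Generic[ContextT]):
    batches: list[list[ContextT]]
    num_steps: int
    input_tokens: int = -1


# Old pattern (hypothetical): batch and num_steps passed separately.
def execute_old(batch: list[FakeContext], num_steps: int) -> list[str]:
    return [f"{ctx.request_id}:{num_steps}" for ctx in batch]


# New pattern (hypothetical): one inputs object carries both values.
def execute(inputs: SketchInputs[FakeContext]) -> list[str]:
    flat = [ctx for batch in inputs.batches for ctx in batch]
    return [f"{ctx.request_id}:{inputs.num_steps}" for ctx in flat]


contexts = [FakeContext("a"), FakeContext("b")]
print(execute(SketchInputs(batches=[contexts], num_steps=4)))  # ['a:4', 'b:4']
```

The single-object signature means new per-call fields (such as input_tokens or batch_type) can be added without changing every function that consumes the inputs.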

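Parameters:

batches (list[list[TextGenerationContextType]]): Variable-length list of batches, where each batch is a list of contexts.

num_steps (int): Number of token generation steps to run.

input_tokens (int): Number of input tokens. Defaults to -1.

batch_type (BatchType): Type of batch. Defaults to BatchType.TG.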

batch_echo

property batch_echo: list[bool]


List indicating whether echo is enabled for each context in the batch.

batch_top_log_probs

property batch_top_log_probs: list[int]


List of requested top log probabilities per context in the batch.

batch_type

batch_type: BatchType = 'TG'


Type of batch.

batches

batches: list[list[TextGenerationContextType]]


Variable-length list of batches, where each batch is a list of contexts.

There can be multiple batches when using data parallelism, in which case each batch is mapped to a different device replica.
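As a rough illustration of the data-parallel layout, the sketch below builds one inner list per replica and walks them with enumerate, so the outer index plays the role of the replica index. The round-robin assignment in split_round_robin is an assumption made for the example; this page does not say how contexts are distributed across replicas, and FakeContext is a hypothetical placeholder.

```python
from dataclasses import dataclass


@dataclass
class FakeContext:
    # Hypothetical placeholder for a text generation context.
    request_id: str


def split_round_robin(contexts: list[FakeContext], num_replicas: int) -> list[list[FakeContext]]:
    """Hypothetical helper: deal contexts to replicas in round-robin order."""
    batches: list[list[FakeContext]] = [[] for _ in range(num_replicas)]
    for i, ctx in enumerate(contexts):
        batches[i % num_replicas].append(ctx)
    return batches


contexts = [FakeContext(f"req{i}") for i in range(5)]
batches = split_round_robin(contexts, num_replicas=2)

# The outer index identifies the device replica that runs each batch.
for replica_id, batch in enumerate(batches):
    print(replica_id, [ctx.request_id for ctx in batch])
# 0 ['req0', 'req2', 'req4']
# 1 ['req1', 'req3']
```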

enable_echo

property enable_echo: bool


True if any context in the batch has echo enabled.

enable_log_probs

property enable_log_probs: bool


True if any context in the batch requests log probabilities.
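The aggregate flags follow from the per-context lists. The sketch below is one plausible way to derive batch_echo, batch_top_log_probs, enable_echo, and enable_log_probs, under stated assumptions: the context attributes echo and top_log_probs are hypothetical names, the per-context lists are computed over all replicas (the flattened batch), and a top-log-probs request of 0 is taken to mean that no log probabilities were requested.

```python
from dataclasses import dataclass


# Hypothetical context; `echo` and `top_log_probs` are assumed attribute names.
@dataclass
class FakeContext:
    echo: bool = False
    top_log_probs: int = 0


# Stand-in showing how the documented properties could be derived.
# This is not the max.interfaces implementation.
@dataclass
class SketchInputs:
    batches: list[list[FakeContext]]

    @property
    def flat_batch(self) -> list[FakeContext]:
        # All contexts across replicas, in replica order.
        return [ctx for batch in self.batches for ctx in batch]

    @property
    def batch_echo(self) -> list[bool]:
        return [ctx.echo for ctx in self.flat_batch]

    @property
    def batch_top_log_probs(self) -> list[int]:
        return [ctx.top_log_probs for ctx in self.flat_batch]

    @property
    def enable_echo(self) -> bool:
        # True if any context has echo enabled.
        return any(self.batch_echo)

    @property
    def enable_log_probs(self) -> bool:
        # Assumption: 0 means "no log probabilities requested".
        return any(n > 0 for n in self.batch_top_log_probs)


inputs = SketchInputs(
    batches=[[FakeContext(), FakeContext(top_log_probs=3)], [FakeContext(echo=True)]]
)
print(inputs.batch_echo)           # [False, False, True]
print(inputs.batch_top_log_probs)  # [0, 3, 0]
print(inputs.enable_echo, inputs.enable_log_probs)  # True True
```

Reducing each per-context list with any() gives a single flag that can gate optional work for the whole batch.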

flat_batch

property flat_batch: list[TextGenerationContextType]


Flattened list of contexts across all replicas.
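Flattening nested batches is a one-liner with itertools.chain.from_iterable. The sketch below assumes contexts are ordered replica by replica (this page does not state the order) and uses plain strings as hypothetical stand-ins for context objects; the offsets show how results produced for the flat list could be regrouped per replica.

```python
from itertools import accumulate, chain

# Hypothetical nested batches: one inner list per device replica.
batches = [["a", "b"], ["c"], ["d", "e", "f"]]

# Flatten in replica order (assumed ordering).
flat_batch = list(chain.from_iterable(batches))
print(flat_batch)  # ['a', 'b', 'c', 'd', 'e', 'f']

# Cumulative offsets [0, 2, 3, 6] mark where each replica's slice starts,
# so per-context results on the flat list can be regrouped by replica.
offsets = [0, *accumulate(len(b) for b in batches)]
regrouped = [flat_batch[start:end] for start, end in zip(offsets, offsets[1:])]
print(regrouped == batches)  # True
```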

input_tokens

input_tokens: int = -1


Number of input tokens.

num_steps

num_steps: int

Number of token generation steps to run.