IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Python function

collate_batch

collate_batch()

max.pipelines.modeling.dataprocessing.collate_batch(batch, direction=PaddingDirection.RIGHT, pad_value=0, batch_size=None)


Generates a single batch tensor from a batch of inputs.

The input sequences may have different lengths. Each one is padded with pad_value up to the length of the longest input, on the side selected by direction (right by default).

If batch_size is given, additional entries are added to the batch until it reaches that size.
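The padding behavior described above can be illustrated with a minimal NumPy sketch. It is not the MAX implementation. The function name collate_sketch is hypothetical, direction is a plain "left"/"right" string standing in for PaddingDirection, inputs are assumed to be 1-D integer sequences, filler rows added for batch_size are assumed to contain pad_value, and the returned indices are assumed to be the column of each item's last real token.

```python
import numpy as np


def collate_sketch(batch, direction="right", pad_value=0, batch_size=None):
    """Hypothetical stand-in for collate_batch on 1-D integer sequences.

    Assumption: direction is a string ("left" or "right"), not PaddingDirection.
    """
    if not batch:
        raise ValueError("batch must contain at least one sequence")
    max_len = max(len(seq) for seq in batch)
    # Real items, plus filler rows (assumed to hold pad_value) if batch_size asks for more.
    n_rows = max(len(batch), batch_size or 0)
    out = np.full((n_rows, max_len), pad_value, dtype=np.int64)
    # Assumption: one index per real batch item, none for filler rows.
    last_token_indices = np.empty(len(batch), dtype=np.int64)
    for row, seq in enumerate(batch):
        seq = np.asarray(seq)
        if direction == "right":
            # Tokens start at column 0; padding fills the tail.
            out[row, : len(seq)] = seq
            last_token_indices[row] = len(seq) - 1
        elif direction == "left":
            # Tokens end at the last column; padding fills the head.
            out[row, max_len - len(seq):] = seq
            last_token_indices[row] = max_len - 1
        else:
            raise ValueError(f"unknown direction: {direction!r}")
    return out, last_token_indices


padded, last = collate_sketch([[1, 2, 3], [4, 5]], pad_value=0, batch_size=3)
# padded -> [[1, 2, 3], [4, 5, 0], [0, 0, 0]]
# last   -> [2, 1]
```

With direction="left", the same two inputs produce rows [1, 2, 3] and [pad, 4, 5], and every last-token index equals max_len - 1.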

Returns:

A matrix with all rows padded to the maximum sequence length, and an array of last-token indices (one per batch item).
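To show how the two return values fit together, the sketch below hardcodes a right-padded matrix and its last-token indices, then uses them to select each row's final real token and the matching row of a hypothetical logits array. It assumes the indices are column positions within each row; the logits array and its shape are illustrative only and not part of collate_batch.

```python
import numpy as np

# A right-padded batch of token IDs (pad_value=0).
padded = np.array([[11, 12, 13],
                   [21, 22, 0]])
# Assumed meaning: column of the last real token in each row.
last_token_indices = np.array([2, 1])

# Hypothetical per-position model outputs, shape (batch, seq_len, vocab).
logits = np.arange(2 * 3 * 4, dtype=np.float32).reshape(2, 3, 4)

rows = np.arange(len(padded))
# Final real token of each row, skipping the padding.
last_tokens = padded[rows, last_token_indices]   # -> [13, 22]
# Logits at that position for each row, shape (batch, vocab).
last_logits = logits[rows, last_token_indices]
```

Pairing a row index array with the per-row indices is plain NumPy advanced indexing; it picks one position per row without looping.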

Return type:

A tuple of the padded matrix and the last-token index array.

Raises:

Parameters: