Python function

collate_batch

collate_batch()

max.pipelines.modeling.dataprocessing.collate_batch(batch, direction=PaddingDirection.RIGHT, pad_value=0, batch_size=None)

Generates a single batch tensor from a batch of inputs.

The input tensors may have different lengths. Shorter inputs are padded with pad_value, on the side selected by direction (the right by default), so that every row has the same length.

If batch_size is given, additional values are added to the batch until it reaches that size.

Returns:

A tuple of two elements: a matrix with all rows padded to the maximum sequence length, and an array of last-token indices (one per batch item).
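
To make the padding behavior concrete, here is a minimal sketch in plain Python lists. It is not the library's implementation: the name collate_batch_sketch is hypothetical, only right padding is shown, and two details are assumptions made for illustration (filler rows added for batch_size consist entirely of pad_value, and their last-token index is reported as 0).

```python
def collate_batch_sketch(batch, pad_value=0, batch_size=None):
    """Right-pad variable-length sequences to a common length.

    Illustrative only; this is not max.pipelines' collate_batch.
    """
    if not batch:
        raise ValueError("batch must contain at least one sequence")

    # Every row is padded out to the longest sequence in the batch.
    max_len = max(len(seq) for seq in batch)
    padded = [list(seq) + [pad_value] * (max_len - len(seq)) for seq in batch]

    # With right padding, the last real token of each row sits at len - 1.
    last_token_indices = [len(seq) - 1 for seq in batch]

    # Assumption: batch_size is met by appending rows made only of pad_value,
    # and those filler rows get a placeholder last-token index of 0.
    if batch_size is not None:
        missing = max(0, batch_size - len(padded))
        padded += [[pad_value] * max_len for _ in range(missing)]
        last_token_indices += [0] * missing

    return padded, last_token_indices


padded, last = collate_batch_sketch([[1, 2, 3], [4, 5]], pad_value=0, batch_size=3)
print(padded)  # [[1, 2, 3], [4, 5, 0], [0, 0, 0]]
print(last)    # [2, 1, 0]
```

The last-token indices let a caller pick out the final real token of each row without having to search for where the padding begins.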

Return type:

A tuple of

Raises:

Parameters: