Python function
collate_batch
max.pipelines.modeling.dataprocessing.collate_batch(batch, direction=PaddingDirection.RIGHT, pad_value=0, batch_size=None)
Generates a single batch tensor from a batch of inputs.
The input tensors may have different lengths; pad_value is used to pad them
all to the same length.
If batch_size is given, additional entries are added to the batch up to that
size.
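The padding behavior described above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not the library's implementation. `collate_batch_sketch` and the local `PaddingDirection` enum are hypothetical stand-ins. The sketch also assumes that inputs are 1-D NumPy arrays, that the filler entries added for `batch_size` are single-element vectors of `pad_value`, and that with left padding every last-token index is the final column.

```python
from enum import Enum

import numpy as np


class PaddingDirection(Enum):
    # Hypothetical stand-in for the library's PaddingDirection enum.
    LEFT = "left"
    RIGHT = "right"


def collate_batch_sketch(batch, direction=PaddingDirection.RIGHT, pad_value=0, batch_size=None):
    """Pad 1-D arrays to a common length; return (matrix, last_token_indices)."""
    if not batch:
        raise ValueError("batch must contain at least one item")
    if any(np.asarray(a).ndim != 1 for a in batch):
        raise NotImplementedError("only 1-D vectors are supported")

    items = [np.asarray(a) for a in batch]
    # Assumption: filler entries are length-1 vectors holding pad_value.
    if batch_size is not None:
        items += [np.array([pad_value])] * (batch_size - len(items))

    max_len = max(len(a) for a in items)
    padded = np.full((len(items), max_len), pad_value, dtype=items[0].dtype)
    last_idx = np.empty(len(items), dtype=np.int64)
    for i, a in enumerate(items):
        if direction == PaddingDirection.RIGHT:
            padded[i, : len(a)] = a            # real tokens first, padding on the right
            last_idx[i] = len(a) - 1           # index of the row's last real token
        else:
            padded[i, max_len - len(a):] = a   # padding on the left, real tokens last
            last_idx[i] = max_len - 1          # real tokens always end at the last column
    return padded, last_idx


batch = [np.array([1, 2, 3]), np.array([4, 5])]
matrix, last = collate_batch_sketch(batch, batch_size=3)
print(matrix.tolist())  # [[1, 2, 3], [4, 5, 0], [0, 0, 0]]
print(last.tolist())    # [2, 1, 0]
```

With right padding, each index points at that row's final real token. With left padding, all real tokens end at the last column, so in this sketch the indices are all equal.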
Returns:
A matrix with all rows padded to the maximum sequence length, and an array of last-token indices (one per batch item).
Return type:
A tuple of two arrays: the padded matrix and the last-token indices.
Raises:
- ValueError: if the batch is empty.
- NotImplementedError: if the batch contains anything other than vectors.
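Parameters:
- batch: the input vectors to collate; they may have different lengths.
- direction: the side of each row that receives padding (default PaddingDirection.RIGHT).
- pad_value: the value used for padding (default 0).
- batch_size: if given, the batch is extended with additional entries up to this size (default None).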
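To show how the two return values fit together, the last-token indices can select each row's final real token from the padded matrix using NumPy advanced indexing. The arrays below are hypothetical values shaped like a right-padded result with pad_value 0. They are not output from the library.

```python
import numpy as np

# Hypothetical values shaped like a right-padded (matrix, last-token indices) pair.
padded = np.array([[11, 12, 13],
                   [21, 22, 0]])
last_token_idx = np.array([2, 1])

# Advanced indexing: for each row i, take column last_token_idx[i].
last_tokens = padded[np.arange(len(padded)), last_token_idx]
print(last_tokens.tolist())  # [13, 22]
```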