
Python function

batch_padded_tokens_and_mask


max.pipelines.modeling.dataprocessing.batch_padded_tokens_and_mask(start_pos, tokens)


Batches input tokens and computes a batched attention mask.

Parameters:

  • start_pos (list[int]) – for each batch item, the index of the end of its KV cache (the position at which its new tokens begin).
  • tokens (list[ndarray[tuple[Any, ...], dtype[int64]]]) – the unpadded input tokens for each item in this batch.

Returns:

A tuple of (batched tokens, last token indices per batch item, attention mask).

Return type:

tuple[ndarray[tuple[Any, ...], dtype[integer[Any]]], ndarray[tuple[Any, ...], dtype[integer[Any]]], ndarray[tuple[Any, ...], dtype[float32]]]
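This page does not say how the padding or the mask is built. As an illustration only, the following NumPy sketch shows one way to produce the three return values. It assumes right-padding with a hypothetical `pad_value` of 0 and an additive causal mask (0.0 where attention is allowed, -inf elsewhere) of shape (batch, max_seq_len, max(start_pos) + max_seq_len). The name `batch_padded_tokens_and_mask_sketch`, the `pad_value` parameter, and the mask layout are assumptions, not the MAX implementation.

```python
import numpy as np
import numpy.typing as npt


def batch_padded_tokens_and_mask_sketch(
    start_pos: list[int],
    tokens: list[npt.NDArray[np.int64]],
    pad_value: int = 0,  # hypothetical parameter: padding token id
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64], npt.NDArray[np.float32]]:
    """Illustrative sketch only; NOT the MAX implementation.

    Assumes right-padding and an additive causal mask laid out as
    (batch, max_seq_len, max(start_pos) + max_seq_len).
    """
    if not tokens or len(start_pos) != len(tokens):
        raise ValueError("need one start_pos per token array, and a non-empty batch")
    seq_lens = np.array([len(t) for t in tokens], dtype=np.int64)
    if (seq_lens == 0).any():
        raise ValueError("every batch item needs at least one token")
    batch_size, max_seq_len = len(tokens), int(seq_lens.max())

    # Right-pad each sequence with pad_value up to the longest one.
    batched = np.full((batch_size, max_seq_len), pad_value, dtype=np.int64)
    for row, seq in enumerate(tokens):
        batched[row, : len(seq)] = seq

    # Index of the last real (unpadded) token in each row.
    last_token_idx = seq_lens - 1

    # Query i of item b sits at absolute position start_pos[b] + i and may
    # attend to every key position j <= start_pos[b] + i (cached + new).
    max_key_len = max(start_pos) + max_seq_len
    query_pos = np.arange(max_seq_len)[None, :, None]  # (1, S, 1)
    key_pos = np.arange(max_key_len)[None, None, :]  # (1, 1, K)
    offsets = np.asarray(start_pos, dtype=np.int64)[:, None, None]  # (B, 1, 1)
    allowed = key_pos <= offsets + query_pos  # broadcasts to (B, S, K)
    mask = np.where(allowed, 0.0, -np.inf).astype(np.float32)
    return batched, last_token_idx, mask


tokens = [np.array([5, 6, 7], dtype=np.int64), np.array([8], dtype=np.int64)]
batched, last_idx, mask = batch_padded_tokens_and_mask_sketch([0, 2], tokens)
print(batched.tolist())  # [[5, 6, 7], [8, 0, 0]]
print(last_idx.tolist())  # [2, 0]
print(mask.shape)  # (2, 3, 5)
```

In this sketch, query rows past an item's last real token are padding. They are left causal rather than fully masked, so every row keeps at least one unmasked key (softmax stays finite), but their outputs should be discarded, for example by reading only position `last_idx[b]` of each row.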