
Python function

build_max_lengths_tensor

build_max_lengths_tensor()

max.nn.build_max_lengths_tensor(num_steps, max_seq_length, max_cache_length)


Builds a [num_steps, 2] uint32 buffer of per-step maximum lengths.

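Each row holds two values for one decode step: the maximum sequence length, then the maximum cache length. Row 0 uses max_seq_length and max_cache_length as given. Every later row uses a sequence length of 1, because each subsequent step processes one new token, and a cache length one greater than the previous row's, so row i (for i ≥ 1) holds (1, max_cache_length + i).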

Parameters:

  • num_steps (int) – The number of decode steps to pre-compute lengths for.
  • max_seq_length (int) – The maximum sequence length for the first step.
  • max_cache_length (int) – The maximum cache length for the first step.

Returns:

A Buffer of shape [num_steps, 2] and dtype uint32 containing (max_seq_length, max_cache_length) pairs.

Return type:

Buffer
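
Example:

The row layout described above can be illustrated with a small pure-Python sketch. This is not the library's implementation: sketch_max_lengths is a hypothetical stand-in that returns nested lists, whereas the real function returns a uint32 Buffer. It only shows which values each row is documented to contain.

```python
def sketch_max_lengths(num_steps, max_seq_length, max_cache_length):
    """Hypothetical stand-in that mirrors the documented [num_steps, 2] layout.

    Returns a list of [max_seq_length, max_cache_length] rows, one per decode
    step. The real build_max_lengths_tensor returns a uint32 Buffer instead.
    """
    rows = []
    for step in range(num_steps):
        # The first step uses the caller's max_seq_length; every later step
        # processes exactly one new token.
        seq_len = max_seq_length if step == 0 else 1
        # The cache length starts at max_cache_length and grows by 1 per step.
        cache_len = max_cache_length + step
        rows.append([seq_len, cache_len])
    return rows


print(sketch_max_lengths(3, 5, 10))
# prints [[5, 10], [1, 11], [1, 12]]
```

With num_steps=3, max_seq_length=5, and max_cache_length=10, the first row carries the full sequence length and the remaining rows each account for a single new token while the cache length grows by one.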