Python function
build_max_lengths_tensor
build_max_lengths_tensor()
max.nn.build_max_lengths_tensor(num_steps, max_seq_length, max_cache_length)
Builds a [num_steps, 2] uint32 buffer of per-step maximum lengths.
Each row encodes the maximum sequence length and maximum cache length for
that decode step. The first step uses max_seq_length; subsequent steps
use 1 (one new token per step). Cache length increases by 1 each step.
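The row layout described above can be sketched in plain NumPy. This is an illustrative reference only, not the MAX implementation: the helper name `max_lengths_reference` is hypothetical, it returns a NumPy array rather than a Buffer, and it assumes the cache length in the first row equals `max_cache_length` and grows by 1 in each later row.

```python
import numpy as np


def max_lengths_reference(num_steps, max_seq_length, max_cache_length):
    """Hypothetical NumPy sketch of the documented [num_steps, 2] layout."""
    rows = np.empty((num_steps, 2), dtype=np.uint32)
    for step in range(num_steps):
        # Column 0: the first step uses max_seq_length, later steps use 1.
        rows[step, 0] = max_seq_length if step == 0 else 1
        # Column 1: assumed to start at max_cache_length and grow by 1 per step.
        rows[step, 1] = max_cache_length + step
    return rows


table = max_lengths_reference(num_steps=3, max_seq_length=5, max_cache_length=10)
print(table.tolist())  # [[5, 10], [1, 11], [1, 12]]
```

Under those assumptions, each row after the first corresponds to a single new token attending to a cache that is one entry longer than in the previous row.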
Parameters:
Returns:
A Buffer of shape [num_steps, 2] and dtype uint32 containing (max_seq_length, max_cache_length) pairs.
Return type: