Python function

compute_log_probabilities_ragged

max.pipelines.lib.log_probabilities.compute_log_probabilities_ragged(device, model, *, input_row_offsets, logits, next_token_logits, tokens, sampled_tokens, batch_top_n, batch_echo)

Computes the log probabilities for ragged model outputs.

Parameters:

  • device (Device) – Device on which to do the bulk of the log probabilities computation. A small amount of computation still occurs on the host regardless of this setting.
  • model (Model) – A compiled graph produced by the log_probabilities_ragged_graph function.
  • input_row_offsets (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – Token offsets into token-indexed buffers, by batch index. Must have one more element than there are batch items: batch item n covers the half-open token index range [input_row_offsets[n], input_row_offsets[n+1]).
  • logits (Buffer | None) – Tensor of logits with shape (tokens, vocab_dim). The token dimension is mapped to batch items using input_row_offsets. May be omitted only if all batch_echo values are False.
  • next_token_logits (Buffer) – Tensor of logits with shape (batch_dim, vocab_dim), holding the logits for the next token of each batch item.
  • tokens (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – Flat token array of shape (total_tokens,) for the whole batch. The index range for each batch item is given by input_row_offsets.
  • sampled_tokens (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – Array of shape (batch_dim,) holding the sampled token for each batch item.
  • batch_top_n (Sequence[int]) – Number of top log probabilities to return for each input in the batch. For any element where top_n == 0, the LogProbabilities for that input is skipped.
  • batch_echo (Sequence[bool]) – For each input in the batch, whether to include input tokens in the returned log probabilities.
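
The ragged layout described by input_row_offsets and tokens can be illustrated with plain NumPy. This is only a sketch of the indexing convention stated above; the token values and the split_ragged helper are hypothetical and are not part of the max API.

```python
import numpy as np

# Hypothetical batch of three sequences with lengths 3, 1, and 2,
# stored back to back in one flat (total_tokens,) array.
tokens = np.array([11, 12, 13, 21, 31, 32], dtype=np.int64)

# One more element than there are batch items: item n covers
# [input_row_offsets[n], input_row_offsets[n + 1]).
input_row_offsets = np.array([0, 3, 4, 6], dtype=np.int64)


def split_ragged(flat, row_offsets):
    """Return one slice per batch item from a flat, token-indexed array.

    Illustrative helper only; not a function provided by the library.
    """
    return [
        flat[row_offsets[n]:row_offsets[n + 1]]
        for n in range(len(row_offsets) - 1)
    ]


per_item = split_ragged(tokens, input_row_offsets)

# Sequence lengths are the differences between consecutive offsets.
lengths = np.diff(input_row_offsets)
```

The same offsets would index the first dimension of a (tokens, vocab_dim) logits tensor, which is why a single offsets array is enough to describe every token-indexed buffer.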

Returns:

Computed log probabilities for each item in the batch.

Return type:

list[LogProbabilities | None]
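
To make the quantities involved concrete, the following NumPy sketch computes log probabilities from next-token logits and keeps the top n per batch item. It is a reference illustration under stated assumptions, not the library's implementation: the log_softmax and top_n_log_probs helpers are hypothetical, the result is a plain tuple rather than a LogProbabilities object, and returning None when top_n == 0 is an assumption based on the documented return type.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last (vocab) dimension."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def top_n_log_probs(logits_row, sampled_token, top_n):
    """Return (sampled_log_prob, {token_id: log_prob}) for one batch item.

    Hypothetical helper. Returns None when top_n == 0, which is an
    assumed reading of "skipped" in the parameter description.
    """
    if top_n == 0:
        return None
    log_probs = log_softmax(logits_row)
    # Indices of the top_n largest log probabilities, highest first.
    top_ids = np.argsort(-log_probs)[:top_n]
    top = {int(i): float(log_probs[i]) for i in top_ids}
    return float(log_probs[sampled_token]), top


# Hypothetical (batch_dim, vocab_dim) logits for two batch items.
next_token_logits = np.array([[2.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
sampled_tokens = np.array([0, 2])
batch_top_n = [2, 0]

results = [
    top_n_log_probs(row, tok, n)
    for row, tok, n in zip(next_token_logits, sampled_tokens, batch_top_n)
]
```

Here the second batch item requests no log probabilities, so its entry in results is None, while the first entry holds the sampled token's log probability together with its two most likely tokens.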