
Python function

compute_log_probabilities_ragged


max.pipelines.lib.log_probabilities.compute_log_probabilities_ragged(device, model, *, input_row_offsets, logits, next_token_logits, tokens, sampled_tokens, batch_top_n, batch_echo)


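Computes log probabilities for ragged model outputs, where the variable-length token sequences of all batch items are packed into flat, token-indexed buffers.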
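
As a rough illustration of what "log probabilities" means here, the NumPy sketch below applies a log-softmax over the vocabulary axis of a (batch_dim, vocab_dim) logits array. It then reads off each row's sampled-token log probability and its top-n candidates. This is not the MAX implementation, which runs the compiled graph on the chosen device. The helper names log_softmax and sampled_and_top_n are hypothetical, and using a single top_n for all rows simplifies the per-item batch_top_n.

```python
import numpy as np


def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def sampled_and_top_n(
    next_token_logits: np.ndarray, sampled_tokens: np.ndarray, top_n: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Hypothetical helper (not part of MAX): per-row log probability of the
    sampled token, plus the top_n (token id, log probability) candidates,
    highest first."""
    logp = log_softmax(np.asarray(next_token_logits, dtype=np.float64))
    rows = np.arange(logp.shape[0])
    sampled_logp = logp[rows, sampled_tokens]
    # Stable sort on the negated values, so ties go to the lower token id.
    top_ids = np.argsort(-logp, axis=-1, kind="stable")[:, :top_n]
    top_logp = np.take_along_axis(logp, top_ids, axis=-1)
    return sampled_logp, top_ids, top_logp


next_token_logits = np.array([[2.0, 1.0, 0.0], [0.0, 0.0, 0.0]])
sampled_logp, top_ids, top_logp = sampled_and_top_n(
    next_token_logits, np.array([0, 2]), top_n=2
)
print(top_ids.tolist())  # [[0, 1], [0, 1]]
```

Working in log space with the row maximum subtracted first avoids overflow in exp for large logits. It also keeps very small probabilities representable.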

Parameters:

  • device (Device) – Device on which to do the bulk of the log probabilities computation. A small amount of computation still occurs on the host regardless of this setting.
  • model (Model) – A compiled model built from the graph produced by the ‘log_probabilities_ragged_graph’ function.
  • input_row_offsets (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – Token offsets into token-indexed buffers, indexed by batch item. Should have one more element than there are batch items; batch item n covers token indices [input_row_offsets[n], input_row_offsets[n+1]).
  • logits (Buffer | None) – (total_tokens, vocab_dim) tensor of logits for every input token. The token dimension is mapped to batch items by input_row_offsets. May be omitted only if all ‘batch_echo’ values are False.
  • next_token_logits (Buffer) – (batch_dim, vocab_dim) tensor of logits for the next token of each batch item.
  • tokens (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – (total_tokens,) flat array of input tokens for the whole batch; each batch item's token indices are given by input_row_offsets.
  • sampled_tokens (ndarray[tuple[Any, ...], dtype[integer[Any]]]) – (batch_dim,) array containing the sampled token for each batch item.
  • batch_top_n (Sequence[int]) – Number of top log probabilities to return for each batch item. For any item where top_n == 0, its LogProbabilities is skipped.
  • batch_echo (Sequence[bool]) – For each batch item, whether to include its input tokens in the returned log probabilities.
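
The ragged layout described by these parameters can be sketched in NumPy as follows. The helpers split_ragged and validate_batch_args are hypothetical illustrations, not part of max.pipelines.lib. They assume the offsets start at 0 and end at the total token count. They encode only the shape relationships and the logits/batch_echo rule stated in the parameter list.

```python
from typing import Optional, Sequence

import numpy as np


def split_ragged(input_row_offsets: np.ndarray, flat: np.ndarray) -> list[np.ndarray]:
    """Hypothetical helper: split a token-indexed array (tokens or logits)
    into per-batch-item slices along axis 0.

    Batch item n owns token indices [input_row_offsets[n], input_row_offsets[n + 1]).
    """
    offsets = np.asarray(input_row_offsets)
    if offsets.ndim != 1 or offsets.size < 1:
        raise ValueError("input_row_offsets must be a non-empty 1-D array")
    # Assumption for this sketch: the offsets cover the whole buffer, starting at 0.
    if offsets[0] != 0 or offsets[-1] != flat.shape[0]:
        raise ValueError("offsets must start at 0 and end at the token count")
    if np.any(np.diff(offsets) < 0):
        raise ValueError("offsets must be non-decreasing")
    return [flat[start:end] for start, end in zip(offsets[:-1], offsets[1:])]


def validate_batch_args(
    input_row_offsets: np.ndarray,
    sampled_tokens: np.ndarray,
    batch_top_n: Sequence[int],
    batch_echo: Sequence[bool],
    logits: Optional[np.ndarray],
) -> int:
    """Hypothetical helper: check per-batch lengths and return the batch size."""
    batch_size = len(input_row_offsets) - 1  # one more offset than batch items
    for name, n in (
        ("sampled_tokens", len(sampled_tokens)),
        ("batch_top_n", len(batch_top_n)),
        ("batch_echo", len(batch_echo)),
    ):
        if n != batch_size:
            raise ValueError(f"{name} has {n} entries, expected {batch_size}")
    # Per the parameter list: logits may be omitted only if no item echoes.
    if logits is None and any(batch_echo):
        raise ValueError("logits is required when any batch_echo value is True")
    return batch_size


# Three batch items with 3, 1, and 2 input tokens.
tokens = np.array([11, 12, 13, 21, 31, 32])
input_row_offsets = np.array([0, 3, 4, 6])
print([t.tolist() for t in split_ragged(input_row_offsets, tokens)])
# [[11, 12, 13], [21], [31, 32]]
print(validate_batch_args(input_row_offsets, np.array([5, 7, 9]), [0, 2, 1], [False] * 3, None))
# 3
```

Because slicing happens along axis 0, the same split applies unchanged to a (total_tokens, vocab_dim) logits array.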

Returns:

Computed log probabilities for each item in the batch.

Return type:

list[LogProbabilities | None]
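
Entries in the returned list can be None, so callers usually need to skip them while keeping each result aligned with its batch index. A minimal, generic sketch follows. The helper name computed_entries is hypothetical, and strings stand in for LogProbabilities objects.

```python
from typing import Optional, Sequence, TypeVar

T = TypeVar("T")


def computed_entries(results: Sequence[Optional[T]]) -> list[tuple[int, T]]:
    """Hypothetical helper: return (batch_index, result) pairs for non-None
    results, preserving each result's position in the original batch."""
    return [(i, r) for i, r in enumerate(results) if r is not None]


# Stand-in values; real entries would be LogProbabilities objects.
print(computed_entries(["lp0", None, "lp2"]))
# [(0, 'lp0'), (2, 'lp2')]
```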