Python module
compute_log_probabilities
compute_log_probabilities()
max.pipelines.nn.compute_log_probabilities.compute_log_probabilities(get_logits_and_samples: Callable[[int, bool], tuple[numpy.ndarray, numpy.ndarray] | None], batch_top_n: list[int], batch_echo: list[bool]) → list[max.pipelines.response.LogProbabilities | None]
Computes log probabilities for each item in a batch.
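The docstring does not spell out the formula. Assuming the logits are unnormalized scores over the vocabulary, converting one row of logits into log probabilities is a log-softmax. Here $z_{t}$ is the logits row at token position $t$ and $V$ is `vocab_size`:

```latex
\log p_{t,j} = z_{t,j} - \log \sum_{k=1}^{V} \exp\left(z_{t,k}\right), \qquad j = 1, \dots, V
```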
Parameters:
- get_logits_and_samples – Callable that takes a batch index and an echo bool and returns the logits and sampled tokens for that batch item.
  Args:
  - batch_index is an int in the range [0, batch_size).
  - echo is whether that item was requested to echo the input tokens.
  Returns (None if the batch item is empty):
  - Logits should have shape (n_tokens, vocab_size).
  - Sampled tokens should have shape (n_tokens,).
- batch_top_n – Number of top log probabilities to return for each item in the batch. For any item where top_n == 0, computing its LogProbabilities is skipped.
- batch_echo – Per-item flags indicating whether to include the input tokens in the returned log probabilities.
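A minimal sketch of a callable that satisfies the get_logits_and_samples contract, using only NumPy. The per-item data (`_batch_logits`, `_batch_tokens`, `VOCAB_SIZE`) is hypothetical. The echo handling is also an assumption made for illustration: when echo is False the sketch returns only the last position. The reference above does not say how a real caller should use echo.

```python
from __future__ import annotations

import numpy as np

# Hypothetical data for a batch of 2 items with vocab_size = 5.
# Item 0 has 3 token positions; item 1 is empty.
VOCAB_SIZE = 5
_rng = np.random.default_rng(0)
_batch_logits = [
    _rng.standard_normal((3, VOCAB_SIZE)).astype(np.float32),
    None,  # empty batch item
]
_batch_tokens = [
    np.array([1, 4, 2], dtype=np.int64),
    None,
]


def get_logits_and_samples(
    batch_index: int, echo: bool
) -> tuple[np.ndarray, np.ndarray] | None:
    """Return (logits, sampled_tokens) for one batch item, or None if it is empty.

    logits has shape (n_tokens, vocab_size); sampled_tokens has shape (n_tokens,).
    """
    logits = _batch_logits[batch_index]
    tokens = _batch_tokens[batch_index]
    if logits is None:
        return None
    if not echo:
        # Assumption for this sketch: without echo, keep only the last position.
        return logits[-1:], tokens[-1:]
    return logits, tokens
```

Returning `None` for an empty item, rather than zero-length arrays, follows the "None if batch item is empty" rule in the parameter description.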
Returns:
Computed log probabilities for each item in the batch.
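The following is a rough sketch of the kind of per-item computation described above. It is not the MAX implementation. `LogProbsSketch` is a hypothetical stand-in for `max.pipelines.response.LogProbabilities`, and its field names are assumptions. The sketch also assumes that a skipped item (top_n == 0) or an empty item yields `None` in the output list, which is consistent with the `LogProbabilities | None` return type.

```python
from __future__ import annotations

from dataclasses import dataclass

import numpy as np


@dataclass
class LogProbsSketch:
    """Hypothetical stand-in for LogProbabilities; field names are illustrative."""

    token_log_probabilities: list[float]  # log prob of each sampled token
    top_log_probabilities: list[dict[int, float]]  # top_n {token_id: log prob} per position


def log_softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the row max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def compute_log_probabilities_sketch(get_logits_and_samples, batch_top_n, batch_echo):
    results: list[LogProbsSketch | None] = []
    for batch_index, (top_n, echo) in enumerate(zip(batch_top_n, batch_echo)):
        if top_n == 0:
            results.append(None)  # assumption: a skipped item becomes None
            continue
        out = get_logits_and_samples(batch_index, echo)
        if out is None:
            results.append(None)  # empty batch item
            continue
        logits, tokens = out
        logp = log_softmax(logits.astype(np.float64))  # (n_tokens, vocab_size)
        # Log probability of the sampled token at each position.
        sampled = logp[np.arange(len(tokens)), tokens]
        # Token ids of the top_n highest log probabilities at each position.
        top_idx = np.argsort(-logp, axis=-1)[:, :top_n]
        top = [
            {int(t): float(row[t]) for t in idx}
            for row, idx in zip(logp, top_idx)
        ]
        results.append(LogProbsSketch(sampled.tolist(), top))
    return results
```

Passing a callable instead of precomputed arrays lets the caller produce logits lazily, so items that are skipped never have their logits materialized. That is a plausible reading of the design, not something the reference states.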