IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Logit comparison

Logit comparison helps you determine whether a MAX model matches a trusted reference implementation. Instead of checking only whether two models produce plausible text, you check whether they produce similar logits from the same input. This page describes three metrics you can use to measure similarity: mean absolute error, cosine distance, and KL divergence.

How to measure logit similarity

Before a language model selects the next token to generate, it produces a vector of logits. Each position in the vector corresponds to a token in the model's vocabulary, and the logit at that position represents the model's relative preference for that token. Comparing logits (rather than tokens) is helpful for model validation because two models can produce the same token even when their logit distributions differ.
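As a small illustration with made-up numbers (a toy four-token vocabulary, not real model output), the two logit vectors below select the same token under greedy decoding but assign it very different probabilities. The softmax helper here is a plain NumPy sketch:

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    # Subtract the max logit for numerical stability before exponentiating.
    shifted = logits - logits.max()
    exp = np.exp(shifted)
    return exp / exp.sum()


# Hypothetical logits over a toy four-token vocabulary.
reference_logits = np.array([4.0, 1.0, 0.5, -2.0])
candidate_logits = np.array([2.0, 1.8, 1.5, -2.0])

# Both vectors select token 0 under greedy (argmax) decoding...
same_token = np.argmax(reference_logits) == np.argmax(candidate_logits)

# ...but the probability each assigns to that token differs substantially.
p_ref = softmax(reference_logits)[0]
p_cand = softmax(candidate_logits)[0]
print(same_token, round(float(p_ref), 3), round(float(p_cand), 3))
```

A token-level comparison would report these as identical, while a logit comparison exposes the difference.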

When you compare logit values, you don't need to achieve exact equality. Different data types, hardware targets, and kernel implementations can introduce small, insignificant numerical differences in logits. Instead, measure against thresholds that account for the expected numerical variation. While we provide threshold suggestions, you should decide what counts as significant for your own use case.
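To see why exact equality is too strict, the following sketch uses synthetic random values (not real model logits) and round-trips them from float32 through float16, loosely mimicking a lower-precision kernel. The values no longer match exactly, yet the mean absolute error stays small:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic float32 "logits" for a hypothetical 32,000-token vocabulary.
logits_fp32 = rng.normal(loc=0.0, scale=3.0, size=32_000).astype(np.float32)
# Round-trip through float16 to simulate reduced-precision arithmetic.
logits_roundtrip = logits_fp32.astype(np.float16).astype(np.float32)

exactly_equal = np.array_equal(logits_fp32, logits_roundtrip)
mae = float(np.mean(np.abs(logits_fp32 - logits_roundtrip)))
print(exactly_equal, mae)
```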

Mean absolute error

A literal way to compare logits is to subtract the vectors element by element, take the absolute value of each difference, and average the results. This gives you the mean absolute error (MAE), which you can compare to a tolerance threshold.

In the following equation, a is the MAX logit vector, b is the PyTorch logit vector, N is the vocabulary size, and i indexes a token in the vocabulary.

MAE(a, b) = (1 / N) * sum(abs(a_i - b_i) for i in range(N))
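The equation translates directly into NumPy. This is a minimal sketch on toy vectors; mean_absolute_error is an illustrative helper name, and the second example shows why the absolute value matters:

```python
import numpy as np


def mean_absolute_error(a: np.ndarray, b: np.ndarray) -> float:
    """MAE between two logit vectors of equal length N."""
    if a.shape != b.shape:
        raise ValueError(f"shape mismatch: {a.shape} vs {b.shape}")
    return float(np.mean(np.abs(a - b)))


# Toy example: a stands in for MAX logits, b for PyTorch logits.
a = np.array([2.0, -1.0, 0.5, 3.0])
b = np.array([2.1, -1.2, 0.5, 2.9])
print(mean_absolute_error(a, b))  # (0.1 + 0.2 + 0.0 + 0.1) / 4 = 0.1, up to float rounding

# Without abs(), opposite-signed errors cancel out.
c = np.array([1.0, -1.0])
d = np.zeros(2)
print(np.mean(c - d), mean_absolute_error(c, d))  # 0.0 1.0
```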

Cosine distance

Cosine distance measures whether two logit vectors point in similar directions, rather than whether their individual values match exactly. This metric is useful when the vectors differ element by element but still share a similar overall orientation. Aim for a cosine distance of <= 1e-3.

In the following equations, a is the MAX logit vector and b is the PyTorch logit vector.

cos(a, b) = (a · b) / (‖a‖ * ‖b‖)
distance(a, b) = 1 - cos(a, b)
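The following sketch applies these equations to two 1-D toy vectors with plain NumPy. It shows that scaling a vector changes every element but leaves the cosine distance at zero, while changing the sign of one logit changes the direction:

```python
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    """1 - cosine similarity between two 1-D logit vectors."""
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - cos)


a = np.array([2.0, -1.0, 0.5, 3.0])
# Scaling changes every element but not the direction: distance is ~0.
print(cosine_distance(a, 2.0 * a))
# Flipping the sign of one logit rotates the vector away from a.
b = a.copy()
b[1] = 1.0
print(cosine_distance(a, b))
```

That scale invariance is also a limitation: multiplying logits by a constant sharpens or flattens the softmax distribution without changing the cosine distance, so this metric works best alongside MAE or KL divergence.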

KL divergence

Kullback-Leibler (KL) divergence measures how much one probability distribution differs from another. To compute it from logits, first convert each logit vector to probabilities with softmax, and then apply smoothing so that no probability is exactly zero. Without smoothing, a zero probability makes the logarithm in the KL formula infinite or undefined.
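The following toy sketch shows the failure mode that smoothing guards against. When one probability in q underflows to exactly zero, the unsmoothed KL divergence becomes infinite, while the smoothed version stays finite. The eps value matches the default in the script later on this page:

```python
import numpy as np

# Toy distributions over a two-token vocabulary. q's second entry
# underflows to exactly 0.0: exp(-1000) is below the smallest float64.
p = np.array([0.5, 0.5])
q = np.exp(np.array([0.0, -1000.0]))
q /= q.sum()

with np.errstate(divide="ignore"):
    raw_kl = np.sum(p * np.log(p / q))  # log(0.5 / 0.0) is inf

eps = 1e-10
n = p.shape[-1]
# Same smoothing as the equations: each distribution still sums to 1.
p_smooth = (1 - n * eps) * p + eps
q_smooth = (1 - n * eps) * q + eps
smoothed_kl = np.sum(p_smooth * np.log(p_smooth / q_smooth))
print(raw_kl, smoothed_kl)
```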

When you evaluate KL divergence, review both the average and the maximum KL values across the generated positions:

  • A high maximum with a low average indicates a large divergence at one or a few positions.
  • A high maximum and a high average indicate a systematic difference.

The maximum KL divergence should typically be <= 1e-2.
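As a sketch of how you might apply this guidance to per-position KL values, the hypothetical summarize_kl helper below labels a result as a pass, an isolated spike, or a systematic difference. Using the same 1e-2 threshold to decide whether both the maximum and the average count as "high" is an assumption; adjust it for your use case:

```python
import numpy as np


def summarize_kl(kl_by_step: np.ndarray, threshold: float = 1e-2) -> str:
    """Classify per-position KL values by their average/maximum pattern.

    Assumption: one threshold decides whether both the maximum and the
    average count as "high".
    """
    avg_kl = float(np.mean(kl_by_step))
    max_kl = float(np.max(kl_by_step))
    if max_kl <= threshold:
        return "pass"
    if avg_kl <= threshold:
        return "isolated spike"  # high maximum, low average
    return "systematic difference"  # high maximum, high average


print(summarize_kl(np.array([1e-3, 2e-3, 1e-3, 1.5e-3, 2e-3])))  # pass
print(summarize_kl(np.array([1e-3, 4e-2, 1e-3, 1e-3, 1e-3])))  # isolated spike
print(summarize_kl(np.array([3e-2, 5e-2, 4e-2, 2e-2, 6e-2])))  # systematic difference
```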

In the following equations, p is the PyTorch probability distribution, q is the MAX probability distribution, N is the vocabulary size, and eps is a small smoothing constant (the script later on this page uses 1e-10).

# Softmax
p_i = exp(torch_logit_i) / sum(exp(torch_logit_j) for all j)
q_i = exp(max_logit_i) / sum(exp(max_logit_j) for all j)

# Smooth the distributions
p_smooth_i = (1 - N * eps) * p_i + eps
q_smooth_i = (1 - N * eps) * q_i + eps

# KL divergence
KL(p_smooth || q_smooth) = sum(
    p_smooth_i * log(p_smooth_i / q_smooth_i) for all i
)
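The following NumPy sketch chains the three steps (softmax, smoothing, KL) for rows of logits. It subtracts the maximum logit before exponentiating, which leaves the softmax result unchanged but avoids the overflow that the literal formula hits with large logits. The function names and input values are illustrative:

```python
import numpy as np


def smoothed_probs(logits: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    # Numerically stable softmax: subtracting the max logit leaves the
    # result unchanged but prevents exp() from overflowing.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    n = logits.shape[-1]
    return (1 - n * eps) * probs + eps


def kl_from_logits(torch_logits: np.ndarray, max_logits: np.ndarray) -> np.ndarray:
    """KL(p_smooth || q_smooth) per row, with p from PyTorch and q from MAX."""
    p = smoothed_probs(torch_logits)
    q = smoothed_probs(max_logits)
    return np.sum(p * np.log(p / q), axis=-1)


# Large logits: a naive exp(1000.0) would overflow to inf.
torch_logits = np.array([[1000.0, 999.0, 990.0]])
max_logits = np.array([[1000.0, 999.2, 990.0]])
print(kl_from_logits(torch_logits, torch_logits))  # [0.]
print(kl_from_logits(torch_logits, max_logits))
```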

Run a logit comparison script

The preceding sections define each metric individually. The following Python script computes all three metrics for your custom model. It runs the same prompts through your MAX model and a Hugging Face Transformers reference model, capturing the logits used to select each generated token across five generation steps. At each step, it feeds the token that MAX generated back into the reference model so that both models see identical inputs. It then reports the metrics so you can compare them against your thresholds.

The script uses PipelineConfig to load your architecture from PIPELINE_REGISTRY, then sends a TextGenerationRequest per prompt. To capture each step's logits, the request's SamplingParams attaches a LogitsProcessor that the pipeline calls before sampling.

from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import torch
from max.pipelines import PIPELINE_REGISTRY, PipelineConfig
from max.pipelines.context import ProcessorInputs, SamplingParams
from max.pipelines.modeling.types import RequestID, TextGenerationRequest
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "hugging-face-repo-ID"
CUSTOM_ARCHITECTURES = ["path-to-architecture-module"]
PROMPTS = [
    "The capital of France is",
    "Once upon a time, in a faraway kingdom,",
    "The quick brown fox jumps over the lazy",
]

NUM_STEPS = 5


@dataclass
class LogitComparison:
    prompt: str
    avg_abs_mae: float
    avg_cos_dist: float
    avg_kl_div: float
    max_kl_div: float


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    dot = np.sum(a * b, axis=-1)
    denom = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return float(np.mean(1.0 - dot / denom))


def smooth_softmax(logits: np.ndarray, eps: float = 1e-10) -> np.ndarray:
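    # Subtract the per-row max before exponentiating to avoid overflow.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    # Mix in eps so no probability is exactly zero; rows still sum to 1.
    return (1 - logits.shape[-1] * eps) * probs + eps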


def kl_divergence(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    return np.sum(p * np.log(p / q), axis=-1)


class CaptureGeneratedLogits:
    """Captures up to `max_steps` next-token logits per request."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.captured: dict[RequestID, list[np.ndarray]] = {}

    def __call__(self, inputs: ProcessorInputs) -> None:
        bucket = self.captured.setdefault(inputs.context.request_id, [])
        if len(bucket) < self.max_steps:
            bucket.append(inputs.logits[-1, :].to_numpy().copy())


def max_logits_for(prompt: str) -> np.ndarray:
    capture = CaptureGeneratedLogits(NUM_STEPS)
    request = TextGenerationRequest(
        request_id=RequestID(),
        model_name=MODEL_PATH,
        prompt=prompt,
        sampling_params=SamplingParams(
            max_new_tokens=NUM_STEPS,
            ignore_eos=True,
            top_k=1,
            logits_processors=[capture],
        ),
    )
    pipeline.generate([request])
    return np.stack(capture.captured[request.request_id]).astype(np.float64)


def torch_logits_for(prompt: str, max_logits: np.ndarray) -> np.ndarray:
    inputs = hf_tokenizer(prompt, return_tensors="pt").to(device)
    step_logits: list[np.ndarray] = []

    for step in range(NUM_STEPS):
        with torch.no_grad():
            outputs = hf_model(**inputs)

        logits = outputs.logits[0, -1, :].to(torch.float64).cpu().numpy()
        step_logits.append(logits)

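        # Append the token MAX picked at this step (greedy, top_k=1) so the
        # reference model is conditioned on the same sequence at the next step.
        next_token = int(np.argmax(max_logits[step]))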
        next_token_tensor = torch.tensor(
            [[next_token]], device=inputs["input_ids"].device
        )
        inputs["input_ids"] = torch.cat(
            [inputs["input_ids"], next_token_tensor], dim=-1
        )
        inputs["attention_mask"] = torch.cat(
            [inputs["attention_mask"], torch.ones_like(next_token_tensor)],
            dim=-1,
        )

    return np.stack(step_logits)


def compare_prompt(prompt: str) -> LogitComparison:
    max_logits = max_logits_for(prompt)
    hf_logits = torch_logits_for(prompt, max_logits)

    kl_by_step = kl_divergence(
        smooth_softmax(hf_logits),
        smooth_softmax(max_logits),
    )

    return LogitComparison(
        prompt=prompt,
        avg_abs_mae=float(np.mean(np.abs(max_logits - hf_logits))),
        avg_cos_dist=cosine_distance(max_logits, hf_logits),
        avg_kl_div=float(np.mean(kl_by_step)),
        max_kl_div=float(np.max(kl_by_step)),
    )


def print_report(results: list[LogitComparison]) -> None:
    header = (
        f"{'prompt':<45} {'avg_abs_mae':>12} "
        f"{'avg_cos_dist':>12} {'avg_kl_div':>12} {'max_kl_div':>12}"
    )
    print(header)
    print("-" * len(header))
    for r in results:
        print(
            f"{r.prompt[:45]:<45} {r.avg_abs_mae:>12.3e} "
            f"{r.avg_cos_dist:>12.3e} {r.avg_kl_div:>12.3e} {r.max_kl_div:>12.3e}"
        )


pipeline_config = PipelineConfig(
    model_path=MODEL_PATH,
    runtime={"custom_architectures": CUSTOM_ARCHITECTURES},
)
_, pipeline = PIPELINE_REGISTRY.retrieve(pipeline_config)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype="auto"
).to(device)
hf_model.eval()

results = [compare_prompt(prompt) for prompt in PROMPTS]
print_report(results)

This produces output like:

prompt                                         avg_abs_mae avg_cos_dist   avg_kl_div   max_kl_div
-------------------------------------------------------------------------------------------------
The capital of France is                         4.975e-02    1.157e-04    8.317e-04    1.580e-03
Once upon a time, in a faraway kingdom,          6.707e-02    1.367e-04    7.051e-04    1.915e-03
The quick brown fox jumps over the lazy          6.421e-02    2.010e-04    1.243e-03    2.225e-03

In this table, each row corresponds to one of the test prompts defined in the script. The avg_abs_mae, avg_cos_dist, and avg_kl_div columns average each metric across the five generation steps, and max_kl_div reports the largest KL divergence among those steps. Compare these numbers against your predetermined thresholds to see whether the model passes your validation criteria.
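To turn the report into an automated pass/fail check, a sketch like the following compares each row against the cosine-distance (<= 1e-3) and maximum-KL (<= 1e-2) suggestions from earlier on this page. failed_checks is a hypothetical helper, and applying the per-vector cosine threshold to the averaged cosine distance is a simplification. Because this page doesn't suggest an MAE threshold, the MAE check stays off unless you pass max_mae. The sketch redefines LogitComparison so it runs on its own, using the values from the example output:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class LogitComparison:
    # Same fields as the dataclass in the comparison script.
    prompt: str
    avg_abs_mae: float
    avg_cos_dist: float
    avg_kl_div: float
    max_kl_div: float


def failed_checks(
    r: LogitComparison,
    max_cos_dist: float = 1e-3,
    max_kl: float = 1e-2,
    max_mae: float | None = None,
) -> list[str]:
    """Return the names of the metrics that exceed their thresholds."""
    failures = []
    if r.avg_cos_dist > max_cos_dist:
        failures.append("avg_cos_dist")
    if r.max_kl_div > max_kl:
        failures.append("max_kl_div")
    if max_mae is not None and r.avg_abs_mae > max_mae:
        failures.append("avg_abs_mae")
    return failures


# Values copied from the example output.
rows = [
    LogitComparison("The capital of France is", 4.975e-02, 1.157e-04, 8.317e-04, 1.580e-03),
    LogitComparison("Once upon a time, in a faraway kingdom,", 6.707e-02, 1.367e-04, 7.051e-04, 1.915e-03),
    LogitComparison("The quick brown fox jumps over the lazy", 6.421e-02, 2.010e-04, 1.243e-03, 2.225e-03),
]
for r in rows:
    failures = failed_checks(r)
    status = "PASS" if not failures else "FAIL " + ", ".join(failures)
    print(f"{status}  {r.prompt}")
```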
