Logit comparison
Logit comparison helps you determine whether a MAX model matches a trusted reference implementation. Instead of checking only if two models produce plausible text, you check if they produce similar logits from the same input. This page describes three metrics you can use to measure similarity: mean absolute error, cosine distance, and KL divergence.
How to measure logit similarity
Before a language model selects the next token to generate, it produces a vector of logits. Each position in the vector corresponds to a token in the model's vocabulary, and the logit at that position represents the model's relative preference for that token. Comparing logits (rather than tokens) is helpful for model validation because two models can produce the same token even when their logit distributions differ.
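To make the last point concrete, the following minimal sketch uses two made-up logit vectors over a five-token vocabulary. The values and variable names are illustrative assumptions, not output from any real model. Both vectors select the same token, yet the probability each assigns to that token differs substantially.

```python
import numpy as np

# Hypothetical logits over a five-token vocabulary (illustrative values only).
reference_logits = np.array([2.0, 0.5, -1.0, 0.1, -0.3])
candidate_logits = np.array([1.2, 1.1, -1.0, 0.1, -0.3])


def softmax(logits: np.ndarray) -> np.ndarray:
    """Convert logits to probabilities, shifting by the max for stability."""
    shifted = logits - logits.max()
    exp = np.exp(shifted)
    return exp / exp.sum()


# Greedy decoding picks the same token from both vectors...
same_token = int(np.argmax(reference_logits)) == int(np.argmax(candidate_logits))

# ...but the probability assigned to that token is quite different.
top_prob_reference = float(softmax(reference_logits)[0])
top_prob_candidate = float(softmax(candidate_logits)[0])

print(f"same token selected: {same_token}")
print(f"top-token probability: {top_prob_reference:.3f} vs {top_prob_candidate:.3f}")
```

A token-level check would pass here, while a logit-level check would flag the gap.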
When you compare logit values, you don't need to achieve exact equality. Different data types, hardware targets, and kernel implementations can introduce small, insignificant numerical differences in logits. Instead, measure against thresholds that account for the expected numerical variation. While we provide threshold suggestions, you should decide what counts as significant for your own use case.
Mean absolute error
The most direct way to compare logits is to subtract the vectors element by element, take the absolute value of each difference, and average the results. This gives you the mean absolute error (MAE), which you can compare to a tolerance threshold.
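As a small illustration, the sketch below computes MAE with NumPy on two made-up four-element vectors. The function name `mean_absolute_error` and the values are assumptions chosen for the example. It also computes the plain signed mean of the differences to show why the absolute value matters: offsetting errors cancel in the signed mean but not in MAE.

```python
import numpy as np


def mean_absolute_error(a: np.ndarray, b: np.ndarray) -> float:
    """Average of the absolute element-wise differences between a and b."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    return float(np.mean(np.abs(a - b)))


# Hypothetical logit vectors (illustrative values only).
max_logits = np.array([2.0, -1.0, 0.5, 3.0])
ref_logits = np.array([2.5, -1.0, 0.0, 3.0])

mae = mean_absolute_error(max_logits, ref_logits)

# Without the absolute value, the -0.5 and +0.5 differences cancel out.
signed_mean = float(np.mean(max_logits - ref_logits))

print(mae)          # 0.25
print(signed_mean)  # 0.0
```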
In the following equation, a is the MAX logit vector, b is the PyTorch
logit vector, N is the vocabulary size, and i indexes a token in the
vocabulary.
MAE(a, b) = (1 / N) * sum(abs(a_i - b_i) for i in range(N))

Cosine distance
Cosine distance measures whether two logit vectors point in similar
directions, rather than whether their individual values match exactly. This
metric is useful when the vectors differ element by element but still share a
similar overall orientation. Aim for a cosine distance of <= 1e-3.
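The following sketch, using made-up values and a hypothetical helper named `cosine_distance`, shows how the metric behaves. Scaling a vector uniformly leaves its direction unchanged, so the distance stays near zero even though every element differs. Reordering the same values changes the direction and produces a large distance.

```python
import numpy as np


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    """Return 1 - cosine similarity for two 1-D vectors."""
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    cos = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - cos)


# Hypothetical logit vector (illustrative values only).
ref_logits = np.array([1.0, -2.0, 0.5, 3.0])

# Same direction, different magnitude: every element differs by 10%.
scaled_logits = 1.1 * ref_logits

# Same values in reverse order: the vector now points elsewhere.
reordered_logits = ref_logits[::-1].copy()

print(f"scaled:    {cosine_distance(ref_logits, scaled_logits):.3e}")
print(f"reordered: {cosine_distance(ref_logits, reordered_logits):.3e}")
```

Because cosine distance ignores magnitude, it is best read alongside MAE rather than in place of it.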
In the following equations, a is the MAX logit vector and b is the
PyTorch logit vector.
cos(a, b) = (a · b) / (‖a‖ * ‖b‖)
distance(a, b) = 1 - cos(a, b)

KL divergence
Kullback-Leibler (KL) divergence measures the difference between the probability distributions of two sets of logits. To compute this metric, first convert the logits to probabilities with softmax, and then apply smoothing to avoid undefined or unstable values.
When you evaluate KL divergence, review both the average and the maximum KL values across the generated positions:
- A high maximum with a low average indicates one position with a large divergence.
- A high maximum and a high average indicate a systematic difference.
The maximum KL divergence should typically be <= 1e-2.
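The sketch below illustrates the first case on synthetic data: random reference logits for five positions and an eight-token vocabulary, with a single logit perturbed at one position. The shapes, the perturbation size, and the helper names are assumptions for the example. It applies softmax with the same additive smoothing described in this section and then reports the per-position, average, and maximum KL values.

```python
import numpy as np


def smooth_softmax(logits: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    """Softmax over the last axis, then additive smoothing that keeps sum == 1."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return (1 - logits.shape[-1] * eps) * probs + eps


def kl_per_position(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    """KL(p || q) computed independently for each position (row)."""
    return np.sum(p * np.log(p / q), axis=-1)


# Synthetic data: 5 generated positions, 8-token vocabulary.
rng = np.random.default_rng(0)
ref_logits = rng.normal(size=(5, 8))
test_logits = ref_logits.copy()
test_logits[2, 0] += 3.0  # perturb one logit at position 2 only

kl = kl_per_position(smooth_softmax(ref_logits), smooth_softmax(test_logits))

print("per-position KL:", kl)
print("average KL:", kl.mean())
print("maximum KL:", kl.max())
```

Only position 2 contributes, so the maximum is five times the average. A systematic difference would instead raise every position and pull the average up toward the maximum.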
In the following equations, p is the PyTorch probability distribution, q is
the MAX probability distribution, N is the vocabulary size, and eps is a small
smoothing constant (the script later on this page uses 1e-10).
# Softmax
p_i = exp(torch_logit_i) / sum(exp(torch_logit_j) for all j)
q_i = exp(max_logit_i) / sum(exp(max_logit_j) for all j)
# Smooth the distributions
p_smooth_i = (1 - N * eps) * p_i + eps
q_smooth_i = (1 - N * eps) * q_i + eps
# KL divergence
KL(p_smooth || q_smooth) = sum(
    p_smooth_i * log(p_smooth_i / q_smooth_i) for all i
)

Run a logit comparison script
The preceding sections define each metric individually. The following Python script computes all three for your custom model. It runs the same prompts through your MAX model and a reference implementation, capturing the logits used to select each generated token across five generation steps. At each step, the reference model is fed the token that MAX selected, so both models score the same sequence. The script then reports the metrics so you can compare them against your thresholds.
The script uses PipelineConfig to load your architecture from PIPELINE_REGISTRY, then sends a TextGenerationRequest per prompt. To capture each step's logits, the request's SamplingParams attaches a LogitsProcessor that the pipeline calls before sampling.
from __future__ import annotations

from dataclasses import dataclass

import numpy as np
import torch
from max.pipelines import PIPELINE_REGISTRY, PipelineConfig
from max.pipelines.context import ProcessorInputs, SamplingParams
from max.pipelines.modeling.types import RequestID, TextGenerationRequest
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_PATH = "hugging-face-repo-ID"
CUSTOM_ARCHITECTURES = ["path-to-architecture-module"]
PROMPTS = [
    "The capital of France is",
    "Once upon a time, in a faraway kingdom,",
    "The quick brown fox jumps over the lazy",
]
NUM_STEPS = 5


@dataclass
class LogitComparison:
    prompt: str
    avg_abs_mae: float
    avg_cos_dist: float
    avg_kl_div: float
    max_kl_div: float


def cosine_distance(a: np.ndarray, b: np.ndarray) -> float:
    dot = np.sum(a * b, axis=-1)
    denom = np.linalg.norm(a, axis=-1) * np.linalg.norm(b, axis=-1)
    return float(np.mean(1.0 - dot / denom))


def smooth_softmax(logits: np.ndarray, eps: float = 1e-10) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return (1 - logits.shape[-1] * eps) * probs + eps


def kl_divergence(p: np.ndarray, q: np.ndarray) -> np.ndarray:
    return np.sum(p * np.log(p / q), axis=-1)


class CaptureGeneratedLogits:
    """Captures up to `max_steps` next-token logits per request."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps
        self.captured: dict[RequestID, list[np.ndarray]] = {}

    def __call__(self, inputs: ProcessorInputs) -> None:
        bucket = self.captured.setdefault(inputs.context.request_id, [])
        if len(bucket) < self.max_steps:
            bucket.append(inputs.logits[-1, :].to_numpy().copy())


def max_logits_for(prompt: str) -> np.ndarray:
    capture = CaptureGeneratedLogits(NUM_STEPS)
    request = TextGenerationRequest(
        request_id=RequestID(),
        model_name=MODEL_PATH,
        prompt=prompt,
        sampling_params=SamplingParams(
            max_new_tokens=NUM_STEPS,
            ignore_eos=True,
            top_k=1,
            logits_processors=[capture],
        ),
    )
    pipeline.generate([request])
    return np.stack(capture.captured[request.request_id]).astype(np.float64)


def torch_logits_for(prompt: str, max_logits: np.ndarray) -> np.ndarray:
    inputs = hf_tokenizer(prompt, return_tensors="pt").to(device)
    step_logits: list[np.ndarray] = []
    for step in range(NUM_STEPS):
        with torch.no_grad():
            outputs = hf_model(**inputs)
        logits = outputs.logits[0, -1, :].to(torch.float64).cpu().numpy()
        step_logits.append(logits)
        next_token = int(np.argmax(max_logits[step]))
        next_token_tensor = torch.tensor(
            [[next_token]], device=inputs["input_ids"].device
        )
        inputs["input_ids"] = torch.cat(
            [inputs["input_ids"], next_token_tensor], dim=-1
        )
        inputs["attention_mask"] = torch.cat(
            [inputs["attention_mask"], torch.ones_like(next_token_tensor)],
            dim=-1,
        )
    return np.stack(step_logits)


def compare_prompt(prompt: str) -> LogitComparison:
    max_logits = max_logits_for(prompt)
    hf_logits = torch_logits_for(prompt, max_logits)
    kl_by_step = kl_divergence(
        smooth_softmax(hf_logits),
        smooth_softmax(max_logits),
    )
    return LogitComparison(
        prompt=prompt,
        avg_abs_mae=float(np.mean(np.abs(max_logits - hf_logits))),
        avg_cos_dist=cosine_distance(max_logits, hf_logits),
        avg_kl_div=float(np.mean(kl_by_step)),
        max_kl_div=float(np.max(kl_by_step)),
    )


def print_report(results: list[LogitComparison]) -> None:
    header = (
        f"{'prompt':<45} {'avg_abs_mae':>12} "
        f"{'avg_cos_dist':>12} {'avg_kl_div':>12} {'max_kl_div':>12}"
    )
    print(header)
    print("-" * len(header))
    for r in results:
        print(
            f"{r.prompt[:45]:<45} {r.avg_abs_mae:>12.3e} "
            f"{r.avg_cos_dist:>12.3e} {r.avg_kl_div:>12.3e} {r.max_kl_div:>12.3e}"
        )


pipeline_config = PipelineConfig(
    model_path=MODEL_PATH,
    runtime={"custom_architectures": CUSTOM_ARCHITECTURES},
)
_, pipeline = PIPELINE_REGISTRY.retrieve(pipeline_config)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH, torch_dtype="auto"
).to(device)
hf_model.eval()

results = [compare_prompt(prompt) for prompt in PROMPTS]
print_report(results)

This produces output like:
prompt                                         avg_abs_mae avg_cos_dist   avg_kl_div   max_kl_div
-------------------------------------------------------------------------------------------------
The capital of France is                         4.975e-02    1.157e-04    8.317e-04    1.580e-03
Once upon a time, in a faraway kingdom,          6.707e-02    1.367e-04    7.051e-04    1.915e-03
The quick brown fox jumps over the lazy          6.421e-02    2.010e-04    1.243e-03    2.225e-03

In this table, each row corresponds to one of the test prompts defined in the script. The avg_abs_mae, avg_cos_dist, and avg_kl_div columns show each metric averaged across the five generation steps, and max_kl_div shows the largest KL divergence at any single step. Compare these numbers against your predetermined thresholds to see if the model passes your validation criteria.
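One way to automate that comparison is sketched below. It checks the cosine distance and maximum KL values from the sample output against the thresholds suggested on this page (1e-3 and 1e-2). The `Metrics` class, the `failed_checks` helper, and the constant names are hypothetical. The sketch omits an MAE check because this page doesn't suggest an MAE threshold, so add one that fits your use case.

```python
from dataclasses import dataclass

# Thresholds suggested on this page; adjust them for your own use case.
COSINE_DISTANCE_MAX = 1e-3
MAX_KL_DIVERGENCE_MAX = 1e-2


@dataclass
class Metrics:
    """Hypothetical container for the per-prompt values being checked."""

    prompt: str
    avg_cos_dist: float
    max_kl_div: float


def failed_checks(m: Metrics) -> list[str]:
    """Return a description of each threshold that the metrics exceed."""
    failures = []
    if m.avg_cos_dist > COSINE_DISTANCE_MAX:
        failures.append(f"cosine distance {m.avg_cos_dist:.3e} > {COSINE_DISTANCE_MAX:.0e}")
    if m.max_kl_div > MAX_KL_DIVERGENCE_MAX:
        failures.append(f"max KL divergence {m.max_kl_div:.3e} > {MAX_KL_DIVERGENCE_MAX:.0e}")
    return failures


# Values copied from the sample output above.
sample = [
    Metrics("The capital of France is", 1.157e-04, 1.580e-03),
    Metrics("Once upon a time, in a faraway kingdom,", 1.367e-04, 1.915e-03),
    Metrics("The quick brown fox jumps over the lazy", 2.010e-04, 2.225e-03),
]

all_pass = True
for m in sample:
    failures = failed_checks(m)
    all_pass = all_pass and not failures
    status = "PASS" if not failures else "FAIL: " + "; ".join(failures)
    print(f"{m.prompt[:45]:<45} {status}")
```

With the sample values, every prompt stays under both thresholds.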
Next steps
Use these related resources next:
- To investigate why a model doesn't meet your similarity thresholds, see Debug MAX model accuracy.
- To learn about the stricter validation methods MAX requires for model contribution, see Contributing new model architectures.