
Python module

sampling

rejection_sampler()

max.pipelines.lib.sampling.sampling.rejection_sampler(device, *, seed=0)

Builds a graph that implements speculative decoding rejection sampling.

Accepts or rejects each draft token by comparing its target and draft probabilities, and resamples from the target distribution when a token is rejected.

Parameters:

  • device (DeviceRef) – Device for the graph.
  • seed (int) – Random seed for sampling.

Returns:

A graph that takes draft tokens, draft logits, and target logits and outputs accepted tokens and metadata.

Return type:

Graph
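The accept/reject rule described above can be illustrated outside the graph. The following is a minimal pure-Python sketch of the standard speculative decoding rule, assuming one row of logits per draft position. Each draft token is accepted with probability min(1, p_target / p_draft). At the first rejection, the sketch resamples that position from the target distribution and stops. It does not use the MAX Graph API. The helper names (`softmax`, `sample_from`, `rejection_sample`) and the `(tokens, num_accepted)` return shape are hypothetical and do not describe the graph's actual outputs or metadata.

```python
import math
import random
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_from(probs: Sequence[float], rng: random.Random) -> int:
    """Draw an index from a categorical distribution by inverse CDF."""
    r = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return i
    return len(probs) - 1  # guard against floating-point round-off


def rejection_sample(
    draft_tokens: Sequence[int],
    draft_logits: Sequence[Sequence[float]],
    target_logits: Sequence[Sequence[float]],
    seed: int = 0,
) -> Tuple[List[int], int]:
    """Hypothetical helper: returns (output tokens, number of accepted drafts)."""
    rng = random.Random(seed)
    output: List[int] = []
    for i, tok in enumerate(draft_tokens):
        p_draft = softmax(draft_logits[i])
        p_target = softmax(target_logits[i])
        # Assumes p_draft[tok] > 0, which holds if the draft model sampled tok.
        accept_prob = min(1.0, p_target[tok] / p_draft[tok])
        if rng.random() < accept_prob:
            output.append(tok)
        else:
            # First rejection: resample this position from the target
            # distribution and discard the remaining draft tokens.
            output.append(sample_from(p_target, rng))
            return output, i
    return output, len(draft_tokens)
```

When the draft and target logits are identical, every acceptance probability is 1, so all draft tokens are kept.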

rejection_sampler_with_residuals()

max.pipelines.lib.sampling.sampling.rejection_sampler_with_residuals(device, *, seed=0, debug=False)

Builds a rejection sampler with residual sampling for speculative decoding.

Computes acceptance ratios for the draft tokens, finds the first rejected position, samples a replacement token from the residual distribution (target - draft), and generates bonus tokens.

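Parameters:

  • device (DeviceRef) – Device for the graph.
  • seed (int) – Random seed for sampling.
  • debug (bool) – Whether to enable debug mode; defaults to False.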

Return type:

Graph
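The residual variant changes what happens at the first rejection. The sketch below assumes the standard formulation from the speculative decoding literature: the replacement token is drawn from norm(max(0, p_target - p_draft)). If every draft token is accepted, a bonus token is drawn from one extra target row. This is a minimal pure-Python illustration, not the MAX implementation. The helper names and the assumption that `target_logits` carries one more row than `draft_tokens` are hypothetical.

```python
import math
import random
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_from(probs: Sequence[float], rng: random.Random) -> int:
    """Draw an index from a categorical distribution by inverse CDF."""
    r = rng.random()
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return i
    return len(probs) - 1  # guard against floating-point round-off


def residual_distribution(p_target: Sequence[float], p_draft: Sequence[float]) -> List[float]:
    """Residual norm(max(0, p_target - p_draft)) used after a rejection."""
    diffs = [max(0.0, p - q) for p, q in zip(p_target, p_draft)]
    total = sum(diffs)
    if total == 0.0:
        # Only happens when p_target <= p_draft everywhere, i.e. the two
        # distributions coincide and no rejection can occur; fall back to target.
        return list(p_target)
    return [d / total for d in diffs]


def rejection_sample_with_residuals(
    draft_tokens: Sequence[int],
    draft_logits: Sequence[Sequence[float]],
    target_logits: Sequence[Sequence[float]],
    seed: int = 0,
) -> Tuple[List[int], int]:
    """Hypothetical helper; target_logits has one extra row for the bonus token."""
    if len(target_logits) != len(draft_tokens) + 1:
        raise ValueError("expected one extra target row for the bonus token")
    rng = random.Random(seed)
    output: List[int] = []
    for i, tok in enumerate(draft_tokens):
        p_draft = softmax(draft_logits[i])
        p_target = softmax(target_logits[i])
        # Acceptance ratio for this draft token, capped at 1.
        accept_prob = min(1.0, p_target[tok] / p_draft[tok])
        if rng.random() < accept_prob:
            output.append(tok)
        else:
            # First rejection: sample the replacement from the residual.
            output.append(sample_from(residual_distribution(p_target, p_draft), rng))
            return output, i
    # All draft tokens accepted: sample a bonus token from the last target row.
    output.append(sample_from(softmax(target_logits[-1]), rng))
    return output, len(draft_tokens)
```

Sampling from the residual rather than the raw target distribution is what makes the overall output distribution match the target model exactly in the standard analysis.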

token_sampler()

max.pipelines.lib.sampling.sampling.token_sampler(sampling_config, device, return_logits=False)

Builds a sampling graph that samples tokens from logits.

Parameters:

  • sampling_config (SamplingConfig) – Sampling configuration (top-k, temperature, etc.).
  • device (DeviceRef) – Device for the graph inputs and ops.
  • return_logits (bool) – Whether the graph should expose logits as an output.

Returns:

A graph that takes logits (and optional penalty inputs) and outputs tokens.

Return type:

Graph
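`token_sampler` builds its sampling logic as a graph. The sketch below shows only the per-row math for one common configuration, top-k filtering followed by temperature-scaled sampling, in plain Python. The parameter names `top_k` and `temperature` echo the examples in the `sampling_config` description. They are illustrative and may not match the actual `SamplingConfig` fields. Penalty inputs and `return_logits` are omitted, and `sample_token` is a hypothetical helper.

```python
import math
import random
from typing import Optional, Sequence


def sample_token(
    logits: Sequence[float],
    top_k: int = 1,
    temperature: float = 1.0,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical top-k + temperature sampler over one row of logits."""
    if top_k < 1:
        raise ValueError("top_k must be >= 1")
    if temperature <= 0.0:
        raise ValueError("temperature must be positive")
    rng = rng if rng is not None else random.Random(0)
    # Indices of the k largest logits, highest first.
    kept = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Temperature scaling, then stable exponentiation of the kept logits.
    scaled = [logits[i] / temperature for i in kept]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # Inverse-CDF draw over the unnormalized weights.
    r = rng.random() * sum(weights)
    cumulative = 0.0
    for idx, w in zip(kept, weights):
        cumulative += w
        if r < cumulative:
            return idx
    return kept[-1]  # guard against floating-point round-off
```

With `top_k=1` the sketch reduces to greedy argmax decoding. Dividing by a positive temperature preserves the ordering of the logits, so applying top-k before or after scaling keeps the same candidates.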
