
Python function

rejection_sampler

rejection_sampler()

max.pipelines.lib.rejection_sampler(device, *, seed=0)


Builds a graph that implements speculative decoding rejection sampling.

Each draft token is accepted or rejected by comparing its probability under the target and draft distributions. When a token is rejected, a replacement is resampled from the target distribution.
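For reference, the standard speculative-sampling acceptance rule is shown below. This page does not confirm that MAX follows it term for term. Here $p$ is the target distribution, $q$ is the draft distribution, and $x$ is the draft token. The standard formulation draws a rejected position's replacement from the normalized residual $p'$, which keeps the output distributed exactly as $p$. The description above only says the resample comes from the target distribution.

```latex
\Pr[\text{accept } x] = \min\!\left(1, \frac{p(x)}{q(x)}\right),
\qquad
p'(y) = \frac{\max\bigl(0,\, p(y) - q(y)\bigr)}{\sum_{z} \max\bigl(0,\, p(z) - q(z)\bigr)}
```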

Parameters:

  • device (DeviceRef) – Device on which the graph runs.
  • seed (int) – Random seed for sampling. Keyword-only; defaults to 0.

Returns:

A graph that takes draft tokens, draft logits, and target logits and outputs accepted tokens and metadata.
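To make the input/output contract above concrete, here is a minimal pure-Python reference loop. It is a sketch, not the MAX graph, and it rests on several assumptions:

- The helpers `softmax`, `sample`, and `rejection_sample` are hypothetical.
- There is one row of draft and target logits per draft token, so no extra "bonus" token is sampled.
- Processing stops at the first rejected position.
- The "metadata" is reduced to a hypothetical accepted-token count.
- Resampling uses the normalized residual max(0, p - q) from standard speculative sampling. The page itself only says that resampling draws from the target distribution.

```python
import math
import random
from typing import List, Sequence, Tuple


def softmax(logits: Sequence[float]) -> List[float]:
    # Subtract the max logit for numerical stability before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample(weights: Sequence[float], rng: random.Random) -> int:
    # Draw one index from a categorical distribution (weights need not sum to 1).
    return rng.choices(range(len(weights)), weights=weights, k=1)[0]


def rejection_sample(
    draft_tokens: Sequence[int],
    draft_logits: Sequence[Sequence[float]],
    target_logits: Sequence[Sequence[float]],
    seed: int = 0,
) -> Tuple[List[int], int]:
    """Hypothetical reference loop: returns (output_tokens, num_draft_tokens_accepted)."""
    rng = random.Random(seed)
    out: List[int] = []
    for i, tok in enumerate(draft_tokens):
        q = softmax(draft_logits[i])   # draft distribution at position i
        p = softmax(target_logits[i])  # target distribution at position i
        # Accept the draft token with probability min(1, p(x) / q(x)).
        if rng.random() < min(1.0, p[tok] / q[tok]):
            out.append(tok)
            continue
        # Rejected: resample from the normalized residual max(0, p - q),
        # falling back to p if the residual is numerically empty.
        residual = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
        weights = residual if sum(residual) > 0 else p
        out.append(sample(weights, rng))
        return out, i  # i draft tokens were accepted before the rejection
    return out, len(draft_tokens)
```

If the draft and target logits are identical, every acceptance ratio is exactly 1, so all draft tokens are kept. If the target model puts almost all its mass on a different token, the draft token is rejected and the replacement comes from the target-heavy residual.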

Return type:

Graph