
Python function

token_sampler

token_sampler()

max.pipelines.lib.token_sampler(sampling_config, device, return_logits=False)


Builds a graph that samples output tokens from logits according to sampling_config.

Parameters:

  • sampling_config (SamplingConfig) – Sampling configuration (top-k, temperature, etc.).
  • device (DeviceRef) – Device for the graph inputs and ops.
  • return_logits (bool) – Whether the graph should expose the logits as an additional output. Defaults to False.
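The sampling_config bullet mentions top-k and temperature. As a rough illustration of what those two settings do to a single row of logits, here is a plain-Python sketch. sample_top_k is a hypothetical helper, not part of MAX, and the real graph runs as compiled ops on the target device rather than in Python.

```python
import math
import random


def sample_top_k(logits, top_k, temperature, rng):
    """Sample one token id from a row of logits using top-k + temperature.

    Hypothetical illustration only; not the MAX implementation.
    """
    if top_k < 1:
        raise ValueError("top_k must be >= 1")
    if temperature <= 0:
        # Many systems treat temperature 0 as greedy; this sketch rejects it.
        raise ValueError("temperature must be > 0")
    # Keep the indices of the k largest logits.
    candidates = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    # Scale the surviving logits by 1 / temperature.
    scaled = [logits[i] / temperature for i in candidates]
    # Numerically stable softmax numerators (subtract the max before exp).
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    # random.choices normalizes the weights, so no explicit division is needed.
    return rng.choices(candidates, weights=weights, k=1)[0]


rng = random.Random(0)
logits = [2.0, 0.5, 1.0, -1.0]
print(sample_top_k(logits, top_k=1, temperature=1.0, rng=rng))  # top_k=1 is greedy: prints 0
print(sample_top_k(logits, top_k=2, temperature=0.7, rng=rng))  # either 0 or 2
```

With top_k=1 the result is always the argmax; larger top_k and higher temperature spread probability over more candidates.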

Returns:

A graph that takes logits (and optional penalty inputs) as inputs and outputs the sampled tokens, plus the logits when return_logits is True.

Return type:

Graph
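To make the return contract concrete, the following plain-Python sketch mimics the graph's outputs for a batch of logits. run_sampler is a hypothetical stand-in, not a MAX API: it uses greedy (argmax) selection instead of the configured sampling strategy and ignores penalty inputs, so only the shape of its outputs (tokens always, logits only when return_logits is True) is meant to reflect the documented behavior.

```python
from typing import Dict, List


def run_sampler(logits: List[List[float]], return_logits: bool = False) -> Dict[str, object]:
    """Hypothetical stand-in for executing the sampling graph on a batch.

    Greedy (argmax) selection is used for simplicity; the real graph samples
    according to its SamplingConfig.
    """
    # One token id per row of logits.
    tokens = [max(range(len(row)), key=row.__getitem__) for row in logits]
    outputs: Dict[str, object] = {"tokens": tokens}
    if return_logits:
        # Mirrors return_logits=True: logits are exposed as an extra output.
        outputs["logits"] = logits
    return outputs


batch = [[0.1, 3.0, -2.0], [1.5, 0.2, 0.9]]
print(run_sampler(batch))                               # {'tokens': [1, 0]}
print(sorted(run_sampler(batch, return_logits=True)))   # ['logits', 'tokens']
```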