

RejectionSampler


class max.nn.RejectionSampler(device, top_k=1, top_p=1, temperature=1.0, seed=0, eps=1e-05)


Bases: Module

Rejection sampler for speculative decoding verification.

Accepts a draft token when the draft logit for that token does not exceed the target logit for the same token by more than eps. Returns a tuple (first_rejected_idx, sampled_target_token): the index of the first rejected draft position and a single recovered token sampled from the target at that position.
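The acceptance rule above can be illustrated with a minimal pure-Python sketch. This is not the MAX implementation: the function name reject_sample_sketch, the list-of-lists logit layout, and the extra target row used to produce a "bonus" token when every draft token is accepted are all assumptions. The sketch covers only the greedy case (top_k=1, the default), where the recovered token is taken to be the argmax of the target logits; temperature, top_p, and seed are ignored.

```python
from typing import List, Sequence, Tuple


def reject_sample_sketch(
    draft_tokens: Sequence[int],
    draft_logits: Sequence[Sequence[float]],
    target_logits: Sequence[Sequence[float]],
    eps: float = 1e-5,
) -> Tuple[int, int]:
    """Hypothetical reference for the eps-based acceptance rule (greedy only)."""
    # Assumption: target_logits has one more row than there are draft tokens,
    # so a token can still be produced when every draft token is accepted.
    if len(target_logits) != len(draft_tokens) + 1:
        raise ValueError("expected len(draft_tokens) + 1 rows of target logits")

    # If nothing is rejected, the "first rejected" index is one past the end.
    first_rejected_idx = len(draft_tokens)
    for i, tok in enumerate(draft_tokens):
        # Reject when the draft logit exceeds the target logit by more than eps.
        if draft_logits[i][tok] - target_logits[i][tok] > eps:
            first_rejected_idx = i
            break

    # Greedy recovery (top_k=1): argmax of the target logits at that position.
    row: List[float] = list(target_logits[first_rejected_idx])
    sampled_target_token = max(range(len(row)), key=row.__getitem__)
    return first_rejected_idx, sampled_target_token


# Usage: position 0 is accepted (1.0 vs 1.0), position 1 is rejected
# (2.0 exceeds 1.0 by more than eps), so the recovered token is the argmax
# of target_logits[1].
draft_tokens = [2, 0]
draft_logits = [[0.1, 0.2, 1.0], [2.0, 0.5, 0.1]]
target_logits = [[0.3, 0.1, 1.0], [1.0, 1.5, 0.2], [0.0, 0.0, 3.0]]
print(reject_sample_sketch(draft_tokens, draft_logits, target_logits))  # (1, 1)
```

Only one recovered token is returned because every draft token after the first rejection is discarded; the verified output is the accepted prefix followed by that single target token.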

Parameters: