Mojo function

TopKTopPSamplingFromProbKernel

TopKTopPSamplingFromProbKernel[ProbsLayoutType: TensorLayout, probs_origin: ImmutOrigin, OutputLayoutType: TensorLayout, output_origin: MutOrigin, block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, deterministic: Bool](probs: TileTensor[dtype, ProbsLayoutType, probs_origin], output: TileTensor[out_idx_type, OutputLayoutType, output_origin], indices: UnsafePointer[Scalar[out_idx_type], MutExternalOrigin], top_k_arr: UnsafePointer[Scalar[out_idx_type], MutExternalOrigin], top_k_val: Int, top_p_arr: UnsafePointer[Float32, MutExternalOrigin], top_p_val: Float32, d: Int, rng_seed: UnsafePointer[UInt64, MutExternalOrigin], rng_offset: UInt64)

Kernel for joint top-k + top-p sampling from a probability distribution.

Identical to TopKSamplingFromProbKernel but additionally enforces a nucleus (top-p) constraint: a token is accepted only when both the count of tokens above the pivot is less than k AND the cumulative probability of those tokens is less than p.

When top_p_val = 1.0 and top_p_arr is null, the kernel reduces to top-k-only sampling with no extra overhead, because the cumulative-probability check sum < 1.0 always passes.
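The acceptance condition described above can be illustrated with a small CPU-side sketch in Python. This is not the kernel's implementation: the helper name `accepts` is hypothetical, and treating "above the pivot" as a strict comparison against the candidate token's own probability is an assumption, since this page does not say how ties are handled.

```python
def accepts(probs, candidate, top_k, top_p):
    """Joint top-k + top-p acceptance test for one candidate token.

    Assumption: the pivot is the candidate's own probability and
    "above the pivot" means strictly greater.
    """
    pivot = probs[candidate]
    above = [p for p in probs if p > pivot]
    # Both constraints must hold: fewer than k tokens above the pivot,
    # and their cumulative probability below p.
    return len(above) < top_k and sum(above) < top_p


probs = [0.5, 0.3, 0.15, 0.05]
print(accepts(probs, 1, top_k=2, top_p=0.6))  # one token (0.5) above the pivot
print(accepts(probs, 1, top_k=1, top_p=0.6))  # fails the top-k constraint
print(accepts(probs, 1, top_k=2, top_p=0.5))  # fails the top-p constraint
```

With `top_p=1.0` the second condition passes for this distribution, so the result depends only on the count, which matches the top-k-only behavior noted above.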

Args:

  • probs (TileTensor): Input probability distribution [batch_size, d].
  • output (TileTensor): Output sampled indices [batch_size].
  • indices (UnsafePointer): Optional row indices for batch indexing [batch_size].
  • top_k_arr (UnsafePointer): Optional per-row top_k values [batch_size].
  • top_k_val (Int): Default top_k value if top_k_arr is null.
  • top_p_arr (UnsafePointer): Optional per-row top_p values [batch_size].
  • top_p_val (Float32): Default top_p value if top_p_arr is null.
  • d (Int): Vocabulary size.
  • rng_seed (UnsafePointer): Pointer to seed value. If non-null, rng_seed[0] is used as the seed. If null, defaults to 0.
  • rng_offset (UInt64): Offset for the random number generator.
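The fallback rules for the optional pointer arguments can be summarized with a short Python sketch, using `None` in place of a null pointer. The helper `resolve_row_params` is hypothetical and only mirrors the defaults listed above. Indexing the per-row arrays by the batch row is an assumption, and `indices` is left out because this page does not specify how it remaps rows.

```python
def resolve_row_params(row, top_k_arr, top_k_val, top_p_arr, top_p_val, rng_seed):
    """Pick the effective top_k, top_p, and seed for one batch row.

    None stands in for a null pointer. Assumption: per-row arrays
    are indexed directly by the batch row.
    """
    # Per-row values take precedence; the scalar is the default.
    top_k = top_k_val if top_k_arr is None else top_k_arr[row]
    top_p = top_p_val if top_p_arr is None else top_p_arr[row]
    # A null seed pointer falls back to seed 0.
    seed = 0 if rng_seed is None else rng_seed[0]
    return top_k, top_p, seed


# No per-row arrays and no seed pointer: scalar defaults apply.
print(resolve_row_params(1, None, 40, None, 0.9, None))
# Per-row arrays and a seed pointer are provided.
print(resolve_row_params(1, [5, 7], 40, [0.5, 0.8], 0.9, [123]))
```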
