
Mojo function

TopKSamplingFromProbKernel

TopKSamplingFromProbKernel[probs_origin: ImmutOrigin, probs_shape_types: Variadic[CoordLike], probs_stride_types: Variadic[CoordLike], output_origin: MutOrigin, output_shape_types: Variadic[CoordLike], output_stride_types: Variadic[CoordLike], block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, deterministic: Bool](probs: TileTensor[dtype, probs_origin], output: TileTensor[out_idx_type, output_origin], indices: UnsafePointer[Scalar[out_idx_type], MutExternalOrigin], top_k_arr: UnsafePointer[Scalar[out_idx_type], MutExternalOrigin], top_k_val: Int, d: Int, rng_seed: UInt64, rng_offset: UInt64)

Kernel for top-k sampling from a probability distribution.

This kernel performs top-k sampling by:

  1. Using ternary search to find a pivot threshold.
  2. Iteratively rejecting samples until the acceptance criterion is met.
  3. Sampling an index using uniform random numbers from the Random generator.
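The steps above can be sketched in plain Python for a single row. This is a reference sketch under stated assumptions, not the kernel itself: it runs sequentially rather than across block_size threads, uses Python's random.Random in place of the kernel's Random generator, and models the "ternary search" as two pivots (the sampled probability, plus the midpoint between it and a running upper bound). That pivot scheme is one plausible reading of the description. The function name top_k_sample_row is hypothetical.

```python
import random


def top_k_sample_row(probs, k, rng):
    """Hypothetical sketch: draw one index from the top-k entries of `probs`.

    Draws are made from the mass above a threshold `low`. A draw is accepted
    when fewer than k entries are strictly larger than it (it is in the top k);
    otherwise `low` (and possibly `high`) is tightened using two pivots.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    low = 0.0            # every top-k probability stays strictly above `low`
    high = max(probs)    # running upper bound used for the second pivot
    while True:
        # Inverse-CDF draw restricted to entries with p > low.
        mass = sum(p for p in probs if p > low)
        u = rng.random() * mass
        sampled = None
        acc = 0.0
        for i, p in enumerate(probs):
            if p > low:
                acc += p
                if u < acc:
                    sampled = i
                    break
        if sampled is None:  # float round-off: fall back to the last candidate
            sampled = max(i for i, p in enumerate(probs) if p > low)

        pivot0 = probs[sampled]
        pivot1 = (pivot0 + high) / 2
        # Accept: fewer than k entries beat the sampled value.
        if sum(1 for p in probs if p > pivot0) < k:
            return sampled
        # Reject: the k-th largest value lies above pivot0, so narrow the range.
        if sum(1 for p in probs if p > pivot1) < k:
            low, high = pivot0, pivot1   # k-th largest is in (pivot0, pivot1]
        else:
            low = pivot1                 # k-th largest is above pivot1
```

Each rejection strictly raises low while keeping it below the k-th largest probability, so the candidate set shrinks and the loop terminates. Accepted draws are proportional to the original probabilities within the top-k set.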

Args:

  • probs (TileTensor): Input probability distribution [batch_size, d].
  • output (TileTensor): Output sampled indices [batch_size].
  • indices (UnsafePointer): Optional row indices for batch indexing [batch_size].
  • top_k_arr (UnsafePointer): Optional per-row top_k values [batch_size].
  • top_k_val (Int): Default top_k value, used when top_k_arr is null.
  • d (Int): Vocabulary size (the number of columns in probs).
  • rng_seed (UInt64): Seed for the random number generator.
  • rng_offset (UInt64): Offset for the random number generator.
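The sketch below shows how the optional arguments might combine for each output slot. It assumes a null pointer is modeled as None, that indices[b] selects which row of probs feeds output[b], and that top_k_arr is indexed by that same row; these mappings are assumptions, not confirmed behavior. The helper names are hypothetical, each row is sampled with an exact sort-and-renormalize stand-in rather than the kernel's rejection loop, and the row length plays the role of d.

```python
import random


def sample_top_k_sorted(row, k, rng):
    """Stand-in row sampler: keep the k largest probabilities, renormalize, draw."""
    top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
    return rng.choices(top, weights=[row[i] for i in top], k=1)[0]


def top_k_sampling_batch(probs, batch_size, top_k_val, top_k_arr=None,
                         indices=None, rng_seed=0):
    """Hypothetical host-side model of the argument handling (not the kernel)."""
    rng = random.Random(rng_seed)  # stands in for the kernel's Random generator
    output = []
    for b in range(batch_size):
        # Assumed mapping: indices[b] picks the probs row for output slot b.
        row_idx = indices[b] if indices is not None else b
        # Per-row k falls back to top_k_val when top_k_arr is null (None here).
        k = top_k_arr[row_idx] if top_k_arr is not None else top_k_val
        output.append(sample_top_k_sorted(probs[row_idx], k, rng))
    return output
```

The sort-based stand-in is simple but costs O(d log d) per row, whereas the threshold search described above only needs passes that compare and sum against a pivot.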
