
Mojo function

TopKSamplingFromProbKernel

TopKSamplingFromProbKernel[block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, probs_layout: Layout, output_layout: Layout, deterministic: Bool](probs: LayoutTensor[dtype, probs_layout, MutAnyOrigin], output: LayoutTensor[out_idx_type, output_layout, MutAnyOrigin], indices: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_arr: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_val: Int, d: Int, rng_seed: UInt64, rng_offset: UInt64)

Kernel for top-k sampling from a probability distribution.

This kernel performs top-k sampling by:

  1. Using ternary search to find a pivot threshold.
  2. Rejecting samples iteratively until the acceptance criterion is met.
  3. Sampling an index using uniform random numbers from the Random generator.
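
The steps above can be illustrated with a sequential CPU sketch in Python. It is not the kernel's code: the function name `top_k_sample`, the two-pivot update rule, and the initial bounds are assumptions modeled on a common pivot-based rejection scheme. Each round draws a candidate from the entries above the current lower bound, accepts it if fewer than `k` entries are strictly larger, and otherwise raises the lower bound to one of two pivots (the three resulting regions are what makes the search ternary).

```python
import random


def top_k_sample(probs, k, rng):
    """Sample an index among the k largest entries of probs, weighted by probs.

    Illustrative CPU reference only; names and update rule are assumptions.
    """
    if k < 1:
        raise ValueError("k must be at least 1")
    low, high = 0.0, 1.0
    # Total mass of the entries still eligible for sampling (those above low).
    q = sum(p for p in probs if p > low)
    if q <= 0.0:
        raise ValueError("probs must contain at least one positive entry")

    while True:
        # Draw a candidate by inverse-CDF sampling over entries above low.
        u = rng.random() * q
        acc = 0.0
        sampled = -1
        for i, p in enumerate(probs):
            if p > low:
                acc += p
                sampled = i  # also serves as a fallback if rounding leaves acc <= u
                if acc > u:
                    break

        # Two pivots split the value range into three regions.
        pivot_0 = probs[sampled]
        pivot_1 = (pivot_0 + high) / 2.0
        count_0 = sum(1 for p in probs if p > pivot_0)
        count_1 = sum(1 for p in probs if p > pivot_1)

        # Accept: fewer than k entries are strictly larger than the candidate.
        if count_0 < k:
            return sampled

        # Reject: tighten the bounds so the next draw comes from a smaller set.
        if count_1 < k:
            low, high = pivot_0, pivot_1
            q = sum(p for p in probs if p > pivot_0)
        else:
            low = pivot_1
            q = sum(p for p in probs if p > pivot_1)
```

For example, `top_k_sample([0.05, 0.4, 0.1, 0.3, 0.15], 2, random.Random(0))` can only return index 1 or 3, the two largest entries. The GPU kernel would perform the counting and summation as block-wide reductions rather than Python loops.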

Args:

  • probs (LayoutTensor): Input probability distribution [batch_size, d].
  • output (LayoutTensor): Output sampled indices [batch_size].
  • indices (LegacyUnsafePointer): Optional row indices for batch indexing [batch_size].
  • top_k_arr (LegacyUnsafePointer): Optional per-row top_k values [batch_size].
  • top_k_val (Int): Default top_k value if top_k_arr is null.
  • d (Int): Vocabulary size.
  • rng_seed (UInt64): Seed for the Random number generator.
  • rng_offset (UInt64): Offset for the Random number generator.
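
The optional arguments can be read as per-row overrides with fallbacks. The Python sketch below shows one plausible resolution order. It is an assumption, not confirmed by this page: the helper `resolve_row_and_k` is hypothetical, `None` stands in for a null pointer, and indexing `top_k_arr` by the resolved row (rather than by block position) is a guess.

```python
def resolve_row_and_k(block_idx, indices, top_k_arr, top_k_val):
    """Pick the probs row and top-k value for one batch entry (hypothetical helper)."""
    # Without an indices array, batch entry i reads row i of probs.
    row = block_idx if indices is None else indices[block_idx]
    # Without a per-row array, every row uses the default top_k_val.
    k = top_k_val if top_k_arr is None else top_k_arr[row]
    return row, k
```

Under these assumptions, `resolve_row_and_k(2, None, None, 40)` yields row 2 with k = 40, while passing `indices=[3, 0, 1]` and `top_k_arr=[5, 10, 20, 50]` yields row 1 with k = 10.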
