Mojo function
TopKSamplingFromProbKernel
TopKSamplingFromProbKernel[block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, probs_layout: Layout, output_layout: Layout, deterministic: Bool](probs: LayoutTensor[dtype, probs_layout, MutAnyOrigin], output: LayoutTensor[out_idx_type, output_layout, MutAnyOrigin], indices: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_arr: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_val: Int, d: Int, rng_seed: UInt64, rng_offset: UInt64)
Kernel for top-k sampling from a probability distribution.
This kernel performs top-k sampling by:
- Using ternary search to find a pivot threshold.
- Rejecting samples iteratively until the acceptance criterion is met.
- Sampling an index using uniform random numbers from the Random generator.
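The steps above can be illustrated with a small CPU reference in Python. This is only a sketch of one common way to combine rejection sampling with a two-pivot (ternary-style) threshold search. It is not the kernel's actual code: the function name `top_k_sample`, the bounds `low` and `high`, and the pivot update rule are assumptions made for illustration. A sample is accepted once fewer than `k` entries are strictly larger than the sampled probability, which means the sampled index lies in the top k.

```python
import random


def top_k_sample(probs, k, rng):
    """Hypothetical CPU reference for top-k sampling (not the kernel's code).

    Assumes probs is a non-empty list of non-negative floats with a positive
    sum, and 1 <= k <= len(probs).
    """
    low, high = 0.0, 1.0
    while True:
        # Candidates are the entries strictly above the current threshold.
        cand = [(i, p) for i, p in enumerate(probs) if p > low]
        q = sum(p for _, p in cand)

        # Inverse-CDF sampling over the candidates with one uniform number.
        u = rng.random() * q
        sampled = cand[-1][0]  # fallback guards against rounding at the tail
        acc = 0.0
        for i, p in cand:
            acc += p
            if acc > u:
                sampled = i
                break

        # Two pivots split the remaining range into three parts.
        pivot_0 = probs[sampled]
        pivot_1 = (pivot_0 + high) / 2
        count_0 = sum(1 for p in probs if p > pivot_0)
        count_1 = sum(1 for p in probs if p > pivot_1)

        if count_0 < k:
            return sampled  # sampled index is within the top k: accept
        if count_1 < k:
            low, high = pivot_0, pivot_1  # threshold lies in [pivot_0, pivot_1)
        else:
            low = pivot_1  # at least k entries still exceed pivot_1


rng = random.Random(0)
print(top_k_sample([0.1, 0.4, 0.05, 0.3, 0.15], 2, rng))
```

With `k = 2` and the probabilities shown, only indices 1 and 3 can be returned, because every rejected sample raises `low` to a value that at least `k` entries still exceed.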
Args:
- probs (LayoutTensor): Input probability distribution [batch_size, d].
- output (LayoutTensor): Output sampled indices [batch_size].
- indices (LegacyUnsafePointer): Optional row indices for batch indexing [batch_size].
- top_k_arr (LegacyUnsafePointer): Optional per-row top_k values [batch_size].
- top_k_val (Int): Default top_k value if top_k_arr is null.
- d (Int): Vocabulary size.
- rng_seed (UInt64): Random seed for Random number generator.
- rng_offset (UInt64): Random offset for Random number generator.