Mojo function
TopKSamplingFromProbKernel
TopKSamplingFromProbKernel[probs_origin: ImmutOrigin, probs_shape_types: Variadic[CoordLike], probs_stride_types: Variadic[CoordLike], output_origin: MutOrigin, output_shape_types: Variadic[CoordLike], output_stride_types: Variadic[CoordLike], block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, deterministic: Bool](probs: TileTensor[dtype, probs_origin], output: TileTensor[out_idx_type, output_origin], indices: UnsafePointer[Scalar[out_idx_type], MutExternalOrigin], top_k_arr: UnsafePointer[Scalar[out_idx_type], MutExternalOrigin], top_k_val: Int, d: Int, rng_seed: UInt64, rng_offset: UInt64)
Kernel for top-k sampling from a probability distribution.
This kernel performs top-k sampling by:
- Using ternary search to find a pivot threshold.
- Rejecting samples iteratively until the acceptance criterion is met.
- Sampling an index using uniform random numbers from the random number generator.
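The steps above can be modeled on the CPU with the short Python reference sketch below. It is not the Mojo kernel. It processes one row sequentially and uses Python's `random.Random` instead of the kernel's seeded generator. It also assumes a two-pivot rejection loop as one way to realize the ternary-search step. The function name `top_k_sample_reference` and its pivot-update rule are hypothetical illustrations.

Each draw uses inverse-CDF sampling over the entries above the current lower bound `low`. A draw is accepted when fewer than `k` entries are strictly larger than it. Otherwise the bounds are tightened and the loop resamples.

```python
import random
from typing import Sequence


def top_k_sample_reference(probs: Sequence[float], k: int, rng: random.Random) -> int:
    """Hypothetical CPU reference for top-k rejection sampling (not the Mojo kernel).

    Returns an index i such that fewer than k entries of probs are strictly
    greater than probs[i], drawn with probability proportional to probs[i].
    Assumes k >= 1 and at least one positive entry in probs.
    """
    if k <= 0:
        raise ValueError("k must be positive")
    if max(probs) <= 0.0:
        raise ValueError("probs needs at least one positive entry")
    # Invariants: every top-k entry is > low, and fewer than k entries are > high.
    low, high = 0.0, max(probs)
    while True:
        # Inverse-CDF draw restricted to entries strictly above `low`.
        mass = sum(p for p in probs if p > low)
        u = rng.random() * mass
        cum, sampled = 0.0, -1
        for i, p in enumerate(probs):
            if p > low:
                cum += p
                sampled = i  # last eligible index doubles as a rounding fallback
                if cum > u:
                    break
        pivot0 = probs[sampled]
        # Accept if fewer than k entries beat the sampled probability.
        if sum(1 for p in probs if p > pivot0) < k:
            return sampled
        # Rejected: at least k entries exceed pivot0, so probe a second pivot
        # halfway between pivot0 and high to shrink the search interval.
        pivot1 = (pivot0 + high) / 2
        if sum(1 for p in probs if p > pivot1) < k:
            low, high = pivot0, pivot1  # the top-k cutoff lies in [pivot0, pivot1)
        else:
            low = pivot1  # at least k entries still exceed pivot1
```

For example, with probs [0.1, 0.2, 0.3, 0.4] and k = 2, only indices 2 and 3 can be returned, in roughly a 3:4 ratio. The sketch is correct because the lower bound only ever rises past values that at least k entries exceed. As a result, the accepted draw is proportional to probability over the top-k set.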
Args:
- probs (TileTensor): Input probability distribution [batch_size, d].
- output (TileTensor): Output sampled indices [batch_size].
- indices (UnsafePointer): Optional row indices for batch indexing [batch_size].
- top_k_arr (UnsafePointer): Optional per-row top_k values [batch_size].
- top_k_val (Int): Default top_k value, used when top_k_arr is null.
- d (Int): Vocabulary size.
- rng_seed (UInt64): Seed for the random number generator.
- rng_offset (UInt64): Offset for the random number generator.
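The optional pointer arguments can be read as per-row lookups with fallbacks. The Python sketch below makes two assumptions based on the argument descriptions above. First, a null `indices` means output slot `b` reads row `b` of `probs`. Second, a null `top_k_arr` means every row uses `top_k_val`. The helper name `resolve_row_params` is hypothetical, and `None` stands in for a null `UnsafePointer`.

```python
from typing import Optional, Sequence, Tuple


def resolve_row_params(
    b: int,
    indices: Optional[Sequence[int]],
    top_k_arr: Optional[Sequence[int]],
    top_k_val: int,
) -> Tuple[int, int]:
    """Hypothetical helper: return (probs_row, k) for output slot b.

    None models a null UnsafePointer for `indices` or `top_k_arr`.
    """
    # Without an index array, output slot b samples from probs row b.
    row = b if indices is None else indices[b]
    # Without a per-row array, every row falls back to the scalar default.
    k = top_k_val if top_k_arr is None else top_k_arr[b]
    return row, k
```

Under these assumptions, take a batch of three with indices [2, 0, 2] and no top_k_arr. Slots 0 and 2 would both sample from row 2 of probs, and all three slots would use k = top_k_val.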