Mojo function

TopKSamplingFromProbKernel

TopKSamplingFromProbKernel[ProbsLayoutType: TensorLayout, probs_origin: ImmutOrigin, OutputLayoutType: TensorLayout, output_origin: MutOrigin, block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, deterministic: Bool](probs: TileTensor[dtype, ProbsLayoutType, probs_origin], output: TileTensor[out_idx_type, OutputLayoutType, output_origin], indices: Optional[UnsafePointer[Scalar[out_idx_type], MutExternalOrigin]], top_k_arr: Optional[UnsafePointer[Scalar[out_idx_type], MutExternalOrigin]], top_k_val: Int, d: Int, rng_seed: UInt64, rng_offset: UInt64)

Kernel for top-k sampling from a probability distribution.
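For reference, top-k sampling draws one index from the `k` largest probabilities after renormalizing them to sum to 1. The Python sketch below is a hypothetical sort-based CPU reference, not part of the Mojo API. It can be useful for checking kernel output statistically. Ties at the k-th value are broken by lower index here, and the kernel is not guaranteed to match that choice.

```python
import random
from typing import Sequence


def top_k_sample_reference(probs: Sequence[float], k: int, rng: random.Random) -> int:
    """Sort-based reference for top-k sampling (hypothetical helper)."""
    # Indices of the k largest probabilities. sorted() is stable (also with
    # reverse=True), so ties at the k-th value keep the lower index first.
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # random.choices treats the weights as relative, which renormalizes the
    # kept probabilities so they sum to 1.
    return rng.choices(top, weights=[probs[i] for i in top], k=1)[0]


rng = random.Random(0)
print(top_k_sample_reference([0.1, 0.4, 0.05, 0.3, 0.15], k=2, rng=rng))  # 1 or 3
```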

This kernel performs top-k sampling by:

  1. Using ternary search to find a pivot threshold.
  2. Iteratively rejecting samples until the acceptance criterion is met.
  3. Sampling an index using uniform random numbers from the `Random` generator.
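Taken together, these steps can be sketched on the CPU in Python. This is a minimal sketch under stated assumptions, not the kernel's implementation. `top_k_sample_rejection` is a hypothetical name. The search uses two pivots: the sampled probability, and the midpoint between it and an upper bound. This is one way to realize a ternary search for the k-th largest value, and the kernel's exact pivot update and tie handling may differ. A sample is accepted when fewer than `k` entries exceed its probability.

```python
import random
from typing import Sequence


def top_k_sample_rejection(probs: Sequence[float], k: int, rng: random.Random) -> int:
    """Hypothetical sketch: top-k sampling by rejection with two pivots.

    Invariant: more than k-1 probabilities are strictly greater than `low`,
    so every top-k entry always remains a candidate.
    """
    low = 0.0
    high = max(probs)  # upper bound on the k-th largest probability
    q = sum(probs)     # probability mass of the candidates {i : probs[i] > low}
    while True:
        # Draw a candidate proportional to probs[i] using one uniform number
        # scaled by the candidate mass q (an inverse-CDF scan).
        u = rng.random() * q
        sampled, cum = -1, 0.0
        for i, p in enumerate(probs):
            if p > low:
                sampled = i  # falls back to the last candidate on rounding
                cum += p
                if cum > u:
                    break
        pivot_0 = probs[sampled]
        pivot_1 = (pivot_0 + high) / 2.0

        # Count and mass above each pivot in a single pass
        # (assumption: on a GPU this would be a block-wide reduction).
        count_0 = count_1 = 0
        mass_0 = mass_1 = 0.0
        for p in probs:
            if p > pivot_0:
                count_0 += 1
                mass_0 += p
            if p > pivot_1:
                count_1 += 1
                mass_1 += p

        # Accept: fewer than k entries beat the sample, so it is in the top k.
        if count_0 < k:
            return sampled
        # Reject and narrow the window that contains the k-th largest value.
        if count_1 < k:
            # The k-th largest value lies in (pivot_0, pivot_1].
            low, high, q = pivot_0, pivot_1, mass_0
        else:
            # At least k entries exceed pivot_1, so it is a safe lower bound.
            low, q = pivot_1, mass_1


rng = random.Random(1)
print(top_k_sample_rejection([0.1, 0.4, 0.05, 0.3, 0.15], k=2, rng=rng))  # 1 or 3
```

Each rejection raises `low` to at least the rejected sample's probability. This shrinks the candidate set strictly, so the loop ends after at most `len(probs)` rejections. Each draw is proportional to `probs[i]` over the candidates, and acceptance depends only on top-k membership. As a result, accepted indices follow the renormalized top-k distribution.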

Args: