
Mojo function

topk_softmax_sample

topk_softmax_sample[dtype: DType, out_idx_type: DType, block_size: Int = 1024, top_k_arr_shape_types: Variadic[CoordLike] = RuntimeInt[DType.int64], top_k_arr_stride_types: Variadic[CoordLike] = ComptimeInt[1], temperature_shape_types: Variadic[CoordLike] = RuntimeInt[DType.int64], temperature_stride_types: Variadic[CoordLike] = ComptimeInt[1], seed_shape_types: Variadic[CoordLike] = RuntimeInt[DType.int64], seed_stride_types: Variadic[CoordLike] = ComptimeInt[1]](ctx: DeviceContext, logits: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], sampled_indices: TileTensor[out_idx_type, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], top_k_val: Int, temperature_val: Float32 = 1, seed_val: UInt64 = 0, top_k_arr: Optional[TileTensor[out_idx_type, MutExternalOrigin]] = None, temperature: Optional[TileTensor[DType.float32, MutExternalOrigin]] = None, seed: Optional[TileTensor[DType.uint64, MutExternalOrigin]] = None)

Samples one token index per batch row from the top-K logits, using softmax probabilities.

This kernel performs single-pass top-K selection and categorical sampling:

  1. Finds the k-th largest logit via ternary search.
  2. Computes the softmax over the top-K elements and caches the resulting probabilities in shared memory.
  3. Samples a single token index from the categorical distribution.
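The three steps above can be modeled on the CPU for a single row. The following is a minimal plain-Python sketch under stated assumptions, not the kernel's implementation. It assumes the ternary search narrows a value interval with two pivots per iteration, using a count of logits at or above each pivot. The helper names (`count_ge`, `kth_largest`, `topk_softmax_sample_row`), the fixed iteration count, the exact final resolve, keeping all ties at the threshold, and the inverse-CDF sampler built on Python's `random.Random` are all illustrative assumptions. It also assumes finite logits and `1 <= k <= len(row)`.

```python
import math
import random


def count_ge(row, t):
    """Number of logits in `row` that are >= threshold t."""
    return sum(1 for x in row if x >= t)


def kth_largest(row, k, iters=32):
    """Ternary search over the value range for the k-th largest logit.

    Invariant: count_ge(row, lo) >= k and count_ge(row, hi) < k,
    so the k-th largest value v always satisfies lo <= v < hi.
    """
    lo = min(row)
    hi = math.nextafter(max(row), math.inf)  # strictly above the max
    for _ in range(iters):
        third = (hi - lo) / 3
        m1, m2 = lo + third, hi - third
        if count_ge(row, m2) >= k:
            lo = m2            # v lies in [m2, hi)
        elif count_ge(row, m1) >= k:
            lo, hi = m1, m2    # v lies in [m1, m2)
        else:
            hi = m1            # v lies in [lo, m1)
    # Assumed final step: resolve exactly among the values left in [lo, hi).
    return max(x for x in row if lo <= x < hi and count_ge(row, x) >= k)


def topk_softmax_sample_row(row, k, temperature=1.0, seed=0):
    """Sample one index of `row` from the softmax over its top-k logits."""
    threshold = kth_largest(row, k)
    # Assumption: every logit tied with the threshold is kept.
    top = [(i, x) for i, x in enumerate(row) if x >= threshold]
    # Temperature-scaled softmax, shifted by the max for numerical stability.
    m = max(x for _, x in top)
    weights = [math.exp((x - m) / temperature) for _, x in top]
    total = sum(weights)
    probs = [w / total for w in weights]  # stands in for the shared-memory cache
    # Assumed sampler: inverse-CDF draw from a seeded RNG.
    u = random.Random(seed).random()
    cumulative = 0.0
    for (i, _), p in zip(top, probs):
        cumulative += p
        if u < cumulative:
            return i
    return top[-1][0]  # guard against rounding when u is close to 1
```

For example, `topk_softmax_sample_row([2.0, 0.5, 3.0, -1.0, 1.0], k=2, temperature=0.7, seed=42)` can only return `0` or `2`, the positions of the two largest logits. Each search iteration costs two counting passes but shrinks the interval to a third of its width, which is why a search over values fits a GPU block: every thread counts its slice of the row and the block reduces the counts, with no sort required.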

Parameters:

  • dtype (DType): The data type of the input logits tensor.
  • out_idx_type (DType): The data type of the output sampled indices.
  • block_size (Int): The number of threads per block (default is 1024).
  • top_k_arr_shape_types (Variadic): The shape types of the optional top_k_arr tensor.
  • top_k_arr_stride_types (Variadic): The stride types of the optional top_k_arr tensor.
  • temperature_shape_types (Variadic): The shape types of the optional temperature tensor.
  • temperature_stride_types (Variadic): The stride types of the optional temperature tensor.
  • seed_shape_types (Variadic): The shape types of the optional seed tensor.
  • seed_stride_types (Variadic): The stride types of the optional seed tensor.

Args:

  • ctx (DeviceContext): The context for GPU execution.
  • logits (TileTensor): Input logits tensor with shape [batch_size, vocab_size].
  • sampled_indices (TileTensor): Output buffer for sampled token indices with shape [batch_size].
  • top_k_val (Int): Default number of top elements to sample from for each batch element.
  • temperature_val (Float32): Temperature for softmax scaling (default is 1.0).
  • seed_val (UInt64): Seed for the random number generator (default is 0).
  • top_k_arr (Optional): Optional per-batch top-K values. If provided, overrides top_k_val for each batch element.
  • temperature (Optional): Optional per-batch temperature values. If provided, overrides temperature_val for each batch element.
  • seed (Optional): Optional per-batch seed values. If provided, overrides seed_val for each batch element.
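The override rule for the three optional per-batch arguments can be stated as a small plain-Python model. This is a sketch, not the kernel's code: the function name `resolve_row_params` is hypothetical, and Python lists stand in for the optional `TileTensor` arguments. It shows that each per-batch array overrides only its own scalar default, independently of the other two.

```python
def resolve_row_params(batch_size, top_k_val, temperature_val=1.0, seed_val=0,
                       top_k_arr=None, temperature=None, seed=None):
    """Return a (top_k, temperature, seed) tuple for each batch row.

    Hypothetical helper: a per-batch array, when provided, overrides the
    matching scalar default for every row; each override is independent.
    """
    rows = []
    for b in range(batch_size):
        rows.append((
            top_k_arr[b] if top_k_arr is not None else top_k_val,
            temperature[b] if temperature is not None else temperature_val,
            seed[b] if seed is not None else seed_val,
        ))
    return rows
```

For example, `resolve_row_params(2, 40, seed=[11, 12])` keeps the default `top_k_val` and `temperature_val` for both rows but gives each row its own seed.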
