Mojo function

topk_softmax_sample

topk_softmax_sample[dtype: DType, out_idx_type: DType, block_size: Int = 1024, TopKArrLayoutType: TensorLayout = Layout[RuntimeInt[DType.int64], ComptimeInt[1]], TemperatureLayoutType: TensorLayout = Layout[RuntimeInt[DType.int64], ComptimeInt[1]], SeedLayoutType: TensorLayout = Layout[RuntimeInt[DType.int64], ComptimeInt[1]]](ctx: DeviceContext, logits: TileTensor[dtype, logits.LayoutType, logits.origin, linear_idx_type=logits.linear_idx_type, element_shape_types=logits.element_shape_types], sampled_indices: TileTensor[out_idx_type, sampled_indices.LayoutType, sampled_indices.origin, linear_idx_type=sampled_indices.linear_idx_type, element_shape_types=sampled_indices.element_shape_types], top_k_val: Int, temperature_val: Float32 = 1, seed_val: UInt64 = 0, top_k_arr: Optional[TileTensor[out_idx_type, TopKArrLayoutType, MutExternalOrigin]] = None, temperature: Optional[TileTensor[DType.float32, TemperatureLayoutType, MutExternalOrigin]] = None, seed: Optional[TileTensor[DType.uint64, SeedLayoutType, MutExternalOrigin]] = None)

Samples token indices from top-K logits using softmax probabilities.

This kernel performs single-pass top-K selection and categorical sampling:

  1. Finds the k-th largest logit via ternary search.
  2. Computes softmax over top-K elements and caches them in shared memory.
  3. Samples a single token index from the categorical distribution.
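
The three steps above can be illustrated with a small CPU reference for a single batch row. This is a sketch of the semantics only, not the kernel: the name `topk_softmax_sample_row` is hypothetical, a plain sort stands in for the kernel's ternary search, Python's `random.Random` stands in for the GPU random number generator, and it assumes that temperature divides the logits before the softmax and is strictly positive.

```python
import math
import random


def topk_softmax_sample_row(logits, top_k, temperature=1.0, seed=0):
    """CPU reference for one batch row (hypothetical helper, not the GPU kernel)."""
    k = max(1, min(top_k, len(logits)))

    # Step 1: select the k largest logits. A full sort is used here for
    # clarity; the kernel instead searches for the k-th largest value.
    order = sorted(range(len(logits)), key=lambda i: (-logits[i], i))
    kept = order[:k]

    # Step 2: softmax over the kept logits only. Subtracting the maximum
    # keeps exp() numerically stable. Dividing by temperature is an assumption.
    max_logit = logits[kept[0]]
    weights = [math.exp((logits[i] - max_logit) / temperature) for i in kept]
    total = sum(weights)

    # Step 3: draw one index by walking the cumulative (unnormalized) weights.
    u = random.Random(seed).random() * total
    acc = 0.0
    for i, w in zip(kept, weights):
        acc += w
        if u < acc:
            return i
    return kept[-1]  # guard against floating-point round-off
```

With `top_k=1` the sketch always returns the index of the largest logit, and a fixed seed gives a repeatable draw. The values it draws will not match the kernel's, since the random number generators differ.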

Parameters:

  • dtype (DType): The data type of the input logits tensor.
  • out_idx_type (DType): The data type of the output sampled indices.
  • block_size (Int): The number of threads per block (default is 1024).
  • TopKArrLayoutType (TensorLayout): The layout type of the optional top_k_arr tensor.
  • TemperatureLayoutType (TensorLayout): The layout type of the optional temperature tensor.
  • SeedLayoutType (TensorLayout): The layout type of the optional seed tensor.

Args:

  • ctx (DeviceContext): The context for GPU execution.
  • logits (TileTensor): Input logits tensor with shape [batch_size, vocab_size].
  • sampled_indices (TileTensor): Output buffer for sampled token indices with shape [batch_size].
  • top_k_val (Int): Default number of top elements to sample from for each batch element.
  • temperature_val (Float32): Temperature for softmax scaling (default is 1.0).
  • seed_val (UInt64): Seed for the random number generator (default is 0).
  • top_k_arr (Optional): Optional per-batch top-K values. If provided, overrides top_k_val for each batch element.
  • temperature (Optional): Optional per-batch temperature values. If provided, overrides temperature_val for each batch element.
  • seed (Optional): Optional per-batch seed values. If provided, overrides seed_val for each batch element.
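
The override rule shared by `top_k_arr`, `temperature`, and `seed` can be pictured as below. This is an illustrative Python sketch, not part of the API: `resolve_per_row` is a hypothetical helper, and the length check assumes each optional tensor holds one value per batch element, which this page does not state explicitly.

```python
def resolve_per_row(batch_size, default, per_row=None):
    """Return one value per batch row, preferring per-row values when given.

    Hypothetical helper: mirrors how top_k_arr, temperature, and seed are
    described as overriding top_k_val, temperature_val, and seed_val.
    """
    if per_row is None:
        # No per-batch tensor supplied: every row uses the scalar default.
        return [default] * batch_size
    if len(per_row) != batch_size:
        # Assumption: the optional tensor has exactly one entry per row.
        raise ValueError("expected one value per batch element")
    return list(per_row)
```

For example, `resolve_per_row(3, 50)` yields the scalar default for every row, while `resolve_per_row(2, 1.0, [0.7, 1.3])` yields the per-row values unchanged.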
