
Mojo function

topk_softmax_sample

topk_softmax_sample[dtype: DType, out_idx_type: DType, block_size: Int = 1024, TopKArrLayoutType: TensorLayout = Layout[*?, *?], TemperatureLayoutType: TensorLayout = Layout[*?, *?], SeedLayoutType: TensorLayout = Layout[*?, *?]](ctx: DeviceContext, logits: TileTensor[dtype, linear_idx_type=logits.linear_idx_type, element_size=logits.element_size], sampled_indices: TileTensor[out_idx_type, linear_idx_type=sampled_indices.linear_idx_type, element_size=sampled_indices.element_size], top_k_val: Int, temperature_val: Float32 = 1, seed_val: UInt64 = UInt64(0), top_k_arr: Optional[TileTensor[out_idx_type, TopKArrLayoutType, MutExternalOrigin]] = None, temperature: Optional[TileTensor[DType.float32, TemperatureLayoutType, MutExternalOrigin]] = None, seed: Optional[TileTensor[DType.uint64, SeedLayoutType, MutExternalOrigin]] = None)

Samples token indices from top-K logits using softmax probabilities.

This kernel performs single-pass top-K selection and categorical sampling:

  1. Finds the k-th largest logit via ternary search.
  2. Computes softmax over top-K elements and caches them in shared memory.
  3. Samples a single token index from the categorical distribution.
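The three steps above can be illustrated with a small host-side reference in plain Python. This is a sketch under stated assumptions, not the kernel:

- It processes one row of logits on the CPU.
- The helper names `kth_largest_value` and `topk_softmax_sample_ref` are hypothetical.
- "Ternary search" is read as a search over the value range. Each round tests two pivots and keeps the invariant lo <= answer <= hi, using only counts of elements at or above a pivot.
- Temperature is assumed to divide the logits before the softmax.
- This page does not say how ties at the threshold are handled, so the sketch keeps every tied element.

```python
import math
import random


def kth_largest_value(logits, k, iters=64):
    """Return the k-th largest value of `logits` (assumes 1 <= k <= len(logits)).

    Ternary search over the value range: each round tests two pivots that split
    [lo, hi] into thirds and keeps the invariant lo <= answer <= hi, using only
    "how many elements are >= t?" counts.
    """
    def count_ge(t):
        return sum(1 for x in logits if x >= t)

    lo, hi = min(logits), max(logits)
    for _ in range(iters):
        m1 = lo + (hi - lo) / 3
        m2 = hi - (hi - lo) / 3
        if count_ge(m2) >= k:        # answer >= m2
            lo = m2
        elif count_ge(m1) >= k:      # m1 <= answer < m2
            lo, hi = m1, m2
        else:                        # answer < m1
            hi = m1
    # The answer is an element inside [lo, hi]; snap to it exactly.
    return max(x for x in logits if lo <= x <= hi and count_ge(x) >= k)


def topk_softmax_sample_ref(logits, k, temperature=1.0, rng=None):
    """Sample one index from softmax(logits / temperature) restricted to the top-k."""
    if rng is None:
        rng = random.Random(0)
    # Step 1: threshold = k-th largest logit.
    threshold = kth_largest_value(logits, k)
    # Step 2: softmax over the kept candidates (ties at the threshold are all kept here).
    idx = [i for i, x in enumerate(logits) if x >= threshold]
    m = max(logits[i] for i in idx)  # subtract the max for numerical stability
    weights = [math.exp((logits[i] - m) / temperature) for i in idx]
    # Step 3: categorical sampling by inverting the cumulative weights.
    u = rng.random() * sum(weights)
    acc = 0.0
    for i, w in zip(idx, weights):
        acc += w
        if u < acc:
            return i
    return idx[-1]  # guard against floating-point round-off


if __name__ == "__main__":
    logits = [0.1, 2.0, -1.0, 3.5, 0.7]
    print(kth_largest_value(logits, 2))          # 2.0
    print(topk_softmax_sample_ref(logits, 2))    # index 3 or 1
```

Each search round needs only a count of elements at or above a pivot. A GPU implementation can therefore get that count from a block-wide reduction instead of sorting the row. This page does not specify whether the kernel does exactly that.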

Parameters:

  • dtype (DType): The data type of the input logits tensor.
  • out_idx_type (DType): The data type of the output sampled indices.
  • block_size (Int): The number of threads per block (default is 1024).
  • TopKArrLayoutType (TensorLayout): The layout type of the optional top_k_arr tensor.
  • TemperatureLayoutType (TensorLayout): The layout type of the optional temperature tensor.
  • SeedLayoutType (TensorLayout): The layout type of the optional seed tensor.
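The signature pairs each scalar argument (`top_k_val`, `temperature_val`, `seed_val`) with an optional tensor (`top_k_arr`, `temperature`, `seed`). One plausible reading is that a provided tensor supplies a per-row value and the scalar acts as the fallback. That reading is an assumption here and is not confirmed by this page. The hypothetical helper `resolve_row_params` sketches it in Python, with plain lists standing in for the optional tensors.

```python
def resolve_row_params(row, top_k_val, temperature_val=1.0, seed_val=0,
                       top_k_arr=None, temperature=None, seed=None):
    """Hypothetical: return (k, temperature, seed) for one row.

    Assumes each optional per-row array, when given, overrides the matching
    scalar; otherwise the scalar default is used for every row.
    """
    k = top_k_arr[row] if top_k_arr is not None else top_k_val
    t = temperature[row] if temperature is not None else temperature_val
    s = seed[row] if seed is not None else seed_val
    return k, t, s


if __name__ == "__main__":
    # Row 1 takes k from the per-row array; temperature and seed fall back to scalars.
    print(resolve_row_params(1, 50, top_k_arr=[10, 20]))
```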

Args: