Mojo function
topk_softmax_sample
topk_softmax_sample[dtype: DType, out_idx_type: DType, block_size: Int = 1024, top_k_arr_shape_types: Variadic[CoordLike] = RuntimeInt[DType.int64], top_k_arr_stride_types: Variadic[CoordLike] = ComptimeInt[1], temperature_shape_types: Variadic[CoordLike] = RuntimeInt[DType.int64], temperature_stride_types: Variadic[CoordLike] = ComptimeInt[1], seed_shape_types: Variadic[CoordLike] = RuntimeInt[DType.int64], seed_stride_types: Variadic[CoordLike] = ComptimeInt[1]](ctx: DeviceContext, logits: TileTensor[dtype, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], sampled_indices: TileTensor[out_idx_type, origin, linear_idx_type=linear_idx_type, element_shape_types=element_shape_types], top_k_val: Int, temperature_val: Float32 = 1, seed_val: UInt64 = 0, top_k_arr: Optional[TileTensor[out_idx_type, MutExternalOrigin]] = None, temperature: Optional[TileTensor[DType.float32, MutExternalOrigin]] = None, seed: Optional[TileTensor[DType.uint64, MutExternalOrigin]] = None)
Samples token indices from top-K logits using softmax probabilities.
This kernel performs single-pass top-K selection and categorical sampling:
- Finds the K-th largest logit via ternary search.
- Computes the softmax over the top-K elements and caches them in shared memory.
- Samples a single token index from the resulting categorical distribution.
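The three steps above can be sketched on the CPU for one row of logits. This is a minimal, sequential Mojo sketch, not the GPU kernel. It assumes finite logits and 1 <= k <= vocab size, and it takes a uniform random number `u` in [0, 1) as an argument instead of drawing one from a seeded generator. The names `count_at_least`, `kth_largest`, and `topk_softmax_sample_row` are hypothetical. The search shown here is one way a ternary search for the K-th largest value could work: it narrows a value interval using two probe thresholds per step. Logits that tie the threshold are all kept, which may differ from how the kernel breaks ties, and the shared-memory cache is modeled as a plain `List`.

```mojo
from math import exp


fn count_at_least(values: List[Float32], t: Float32) -> Int:
    # Number of logits that are >= the candidate threshold t.
    var n = 0
    for i in range(len(values)):
        if values[i] >= t:
            n += 1
    return n


fn kth_largest(values: List[Float32], k: Int) -> Float32:
    # Ternary search in value space (assumes 1 <= k <= len(values) and
    # finite logits). Invariant: count(>= lo) >= k and count(>= hi) < k.
    var lo = values[0]
    var hi = values[0]
    for i in range(len(values)):
        lo = min(lo, values[i])
        hi = max(hi, values[i])
    if count_at_least(values, hi) >= k:
        return hi  # the maximum already appears k or more times
    while True:
        var third = (hi - lo) / 3
        var m1 = lo + third
        var m2 = hi - third
        var new_lo = lo
        var new_hi = hi
        if count_at_least(values, m2) >= k:
            new_lo = m2  # answer lies in the upper third
        elif count_at_least(values, m1) >= k:
            new_lo = m1  # answer lies in the middle third
            new_hi = m2
        else:
            new_hi = m1  # answer lies in the lower third
        if new_lo == lo and new_hi == hi:
            break  # lo and hi are adjacent floats, so lo is the k-th largest
        lo = new_lo
        hi = new_hi
    return lo


fn topk_softmax_sample_row(
    logits: List[Float32], k: Int, temperature: Float32, u: Float32
) -> Int:
    # u is a uniform random number in [0, 1); passing it in keeps the
    # sketch deterministic.
    var threshold = kth_largest(logits, k)
    var max_logit = logits[0]
    for i in range(len(logits)):
        max_logit = max(max_logit, logits[i])
    # Unnormalized, temperature-scaled softmax weights of the kept logits;
    # everything below the threshold gets weight 0.
    var weights = List[Float32]()
    var total: Float32 = 0
    for i in range(len(logits)):
        var w: Float32 = 0
        if logits[i] >= threshold:
            w = exp((logits[i] - max_logit) / temperature)
        weights.append(w)
        total += w
    # Inverse-CDF sampling: first index whose running sum exceeds u * total.
    var target = u * total
    var running: Float32 = 0
    var last_kept = 0
    for i in range(len(weights)):
        if weights[i] > 0:
            last_kept = i
            running += weights[i]
            if running > target:
                return i
    return last_kept  # guards against rounding when u is close to 1
```

For example, with logits `[1.0, 3.0, 2.0, 0.5]` and k = 2, the threshold is 2.0 and only indices 1 and 2 keep nonzero weight; `u = 0.0` selects index 1 and `u = 0.9` selects index 2. Subtracting the maximum logit before `exp` keeps the weights in range without changing the normalized probabilities.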
Parameters:
- dtype (DType): The data type of the input logits tensor.
- out_idx_type (DType): The data type of the output sampled indices.
- block_size (Int): The number of threads per block (default is 1024).
- top_k_arr_shape_types (Variadic): The shape types of the optional top_k_arr tensor.
- top_k_arr_stride_types (Variadic): The stride types of the optional top_k_arr tensor.
- temperature_shape_types (Variadic): The shape types of the optional temperature tensor.
- temperature_stride_types (Variadic): The stride types of the optional temperature tensor.
- seed_shape_types (Variadic): The shape types of the optional seed tensor.
- seed_stride_types (Variadic): The stride types of the optional seed tensor.
Args:
- ctx (DeviceContext): The context for GPU execution.
- logits (TileTensor): Input logits tensor with shape [batch_size, vocab_size].
- sampled_indices (TileTensor): Output buffer for sampled token indices with shape [batch_size].
- top_k_val (Int): Default number of top elements to sample from for each batch element; used when top_k_arr is not provided.
- temperature_val (Float32): Temperature for softmax scaling (default is 1.0).
- seed_val (UInt64): Seed for the random number generator (default is 0).
- top_k_arr (Optional): Optional per-batch top-K values. If provided, overrides top_k_val for each batch element.
- temperature (Optional): Optional per-batch temperature values. If provided, overrides temperature_val for each batch element.
- seed (Optional): Optional per-batch seed values. If provided, overrides seed_val for each batch element.
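The override rule for the optional per-batch tensors can be sketched as a small lookup. In this hypothetical helper, `resolve` is not part of the library, and an empty `List` stands in for an absent (`None`) tensor so the sketch does not need to model `TileTensor` or `Optional`.

```mojo
fn resolve[dtype: DType](
    per_row: List[Scalar[dtype]], default: Scalar[dtype], row: Int
) -> Scalar[dtype]:
    # Hypothetical helper: an empty list models an absent optional tensor.
    if len(per_row) == 0:
        return default  # fall back to the scalar argument (e.g. temperature_val)
    return per_row[row]  # the per-batch-element value overrides the scalar
```

Under this model, the effective temperature for batch element `i` would be `resolve[DType.float32](temperatures, temperature_val, i)`, and top-K and seed follow the same pattern.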