Mojo function
topk_softmax_sample
topk_softmax_sample[dtype: DType, out_idx_type: DType, block_size: Int = 1024](ctx: DeviceContext, logits: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sampled_indices: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], top_k_val: Int, temperature_val: Float32 = 1, seed_val: UInt64 = 0, top_k_arr: OptionalReg[LayoutTensor[out_idx_type, Layout.row_major(-1), MutAnyOrigin]] = None, temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutAnyOrigin]] = None)
Samples token indices from top-K logits using softmax probabilities.
This kernel performs single-pass top-K selection and categorical sampling:
- Finds the K-th largest logit using a ternary search.
- Computes the softmax over the top-K elements and caches the results in shared memory.
- Samples a single token index from the resulting categorical distribution.
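The three steps above can be modeled on the CPU with the minimal Python sketch below. It is not Mojo and not the kernel's actual code. It handles one row of logits as a Python list. The "ternary search" is modeled as a three-way split of the value range that keeps the invariant count(lo) >= K > count(hi); the kernel's exact search may differ. Every logit tied with the K-th largest is kept. The softmax assumes that temperature divides the logits. The uniform draw `u` is passed in rather than generated by the kernel's seeded RNG. The names `count_at_least`, `kth_largest_ternary`, and `sample_row` are hypothetical.

```python
import math


def count_at_least(row, t):
    # Number of logits >= t; non-increasing as t grows.
    return sum(1 for x in row if x >= t)


def kth_largest_ternary(row, k, iters=200):
    # Hypothetical helper: three-way split search on the threshold value.
    # Invariant: count_at_least(lo) >= k and count_at_least(hi) < k,
    # so the k-th largest logit always lies in [lo, hi).
    lo = min(row)
    hi = math.nextafter(max(row), math.inf)
    for _ in range(iters):
        third = (hi - lo) / 3
        m1, m2 = lo + third, hi - third
        if count_at_least(row, m2) >= k:
            lo = m2            # answer lies in [m2, hi)
        elif count_at_least(row, m1) >= k:
            lo, hi = m1, m2    # answer lies in [m1, m2)
        else:
            hi = m1            # answer lies in [lo, m1)
    # Snap to an actual logit: the smallest value that is still >= lo.
    return min(x for x in row if x >= lo)


def sample_row(row, k, temperature, u):
    # Hypothetical helper: top-K softmax sampling for one row, given u in [0, 1).
    k = max(1, min(k, len(row)))
    threshold = kth_largest_ternary(row, k)
    # Candidate set: every logit >= threshold (ties at the threshold are kept).
    cand = [(i, x) for i, x in enumerate(row) if x >= threshold]
    # Numerically stable softmax over the candidates only.
    # Assumption: temperature divides the logits before exponentiation.
    m = max(x for _, x in cand)
    weights = [math.exp((x - m) / temperature) for _, x in cand]
    total = sum(weights)
    # Inverse-CDF draw from the categorical distribution.
    acc = 0.0
    for (i, _), w in zip(cand, weights):
        acc += w / total
        if u < acc:
            return i
    return cand[-1][0]  # guard against rounding when u is close to 1
```

For example, with logits `[1.0, 5.0, 3.0, 4.0, 2.0]` and K = 2, the threshold is 4.0. Only indices 1 and 3 can be drawn, with probabilities of roughly 0.73 and 0.27 at temperature 1.0.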
Parameters:
- dtype (DType): The data type of the input logits tensor.
- out_idx_type (DType): The data type of the output sampled indices.
- block_size (Int): The number of threads per block (default is 1024).
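This page does not say how the vocabulary is divided among the `block_size` threads. The Python sketch below assumes one common GPU pattern, purely for illustration: a block-stride loop followed by a block-wide reduction. It simulates that pattern for counting the logits at or above a threshold, which is the quantity a threshold search needs at each step. `block_count_at_least` is a hypothetical name and does not describe the kernel's internals.

```python
def block_count_at_least(row, t, block_size=1024):
    # Hypothetical model of one thread block: thread `tid` handles indices
    # tid, tid + block_size, tid + 2 * block_size, ... (block-stride loop).
    partial = [0] * block_size
    for tid in range(block_size):
        for i in range(tid, len(row), block_size):
            if row[i] >= t:
                partial[tid] += 1
    # Combine per-thread partial counts; on a GPU this would be a
    # block-wide reduction rather than a sequential sum.
    return sum(partial)
```

Under this pattern, a larger `block_size` gives each thread fewer elements to scan. A vocabulary smaller than `block_size` leaves some threads idle.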
Args:
- ctx (DeviceContext): The device context for GPU execution.
- logits (LayoutTensor): Input logits tensor with shape [batch_size, vocab_size].
- sampled_indices (LayoutTensor): Output buffer for the sampled token indices, with shape [batch_size].
- top_k_val (Int): Default number of top elements to sample from for each batch element.
- temperature_val (Float32): Default temperature for softmax scaling (default is 1.0).
- seed_val (UInt64): Default seed for the random number generator (default is 0).
- top_k_arr (OptionalReg): Optional per-batch top-K values. If provided, they override top_k_val for each batch element.
- temperature (OptionalReg): Optional per-batch temperature values. If provided, they override temperature_val for each batch element.
- seed (OptionalReg): Optional per-batch seed values. If provided, they override seed_val for each batch element.
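The override rules above amount to a per-row lookup: when a per-batch array is provided, its entry replaces the matching scalar default. The Python sketch below models this behavior. Plain lists stand in for the `LayoutTensor` arguments, and `None` stands in for an empty `OptionalReg`. To keep the sketch short, it selects the top K by sorting instead of using the kernel's ternary search. It assumes that temperature divides the logits, and it uses Python's `random.Random`, seeded per row, in place of the kernel's RNG, so its draws will not match the GPU kernel's. `resolve_row_params` and `sample_batch` are hypothetical names.

```python
import math
import random


def resolve_row_params(b, top_k_val, temperature_val, seed_val,
                       top_k_arr=None, temperature=None, seed=None):
    # A per-batch entry, when its array is provided, overrides the scalar default.
    k = top_k_arr[b] if top_k_arr is not None else top_k_val
    t = temperature[b] if temperature is not None else temperature_val
    s = seed[b] if seed is not None else seed_val
    return k, t, s


def sample_batch(logits, top_k_val, temperature_val=1.0, seed_val=0,
                 top_k_arr=None, temperature=None, seed=None):
    sampled = []
    for b, row in enumerate(logits):
        k, t, s = resolve_row_params(b, top_k_val, temperature_val, seed_val,
                                     top_k_arr, temperature, seed)
        k = max(1, min(k, len(row)))
        # Top-K indices by sorting (stand-in for the kernel's ternary search).
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        m = row[top[0]]
        # Assumption: temperature divides the logits before the softmax.
        weights = [math.exp((row[i] - m) / t) for i in top]
        # Python's Mersenne Twister stands in for the kernel's RNG (assumption).
        u = random.Random(s).random() * sum(weights)
        choice = top[-1]  # fallback if rounding leaves u >= the running sum
        acc = 0.0
        for i, w in zip(top, weights):
            acc += w
            if u < acc:
                choice = i
                break
        sampled.append(choice)
    return sampled
```

For example, in this sketch `sample_batch(logits, top_k_val=50, top_k_arr=[1, 50])` always returns the argmax for row 0 and samples from the 50 largest logits for row 1.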