
Mojo function

topk_softmax_sample

topk_softmax_sample[dtype: DType, out_idx_type: DType, block_size: Int = 1024](ctx: DeviceContext, logits: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sampled_indices: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], top_k_val: Int, temperature_val: Float32 = 1, seed_val: UInt64 = 0, top_k_arr: OptionalReg[LayoutTensor[out_idx_type, Layout.row_major(-1), MutAnyOrigin]] = None, temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutAnyOrigin]] = None)

Samples token indices from top-K logits using softmax probabilities.
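For reference, the sampling distribution can be written as a temperature-scaled softmax restricted to the top-K set. This assumes the common convention that the temperature T (temperature_val, or the per-batch override) divides the logits z before the softmax. 𝒦 denotes the indices of the K largest logits in a row; every index outside 𝒦 gets probability zero.

$$
p_i = \frac{\exp(z_i / T)}{\sum_{j \in \mathcal{K}} \exp(z_j / T)} \quad \text{for } i \in \mathcal{K}, \qquad p_i = 0 \text{ otherwise.}
$$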

This kernel performs single-pass top-K selection and categorical sampling:

  1. Finds the k-th largest logit via ternary search.
  2. Computes the softmax over the top-K elements and caches those elements in shared memory.
  3. Samples a single token index from the categorical distribution.
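The three steps can be illustrated with a single-row CPU reference in Python. This is a hedged sketch, not the Mojo kernel: `kth_largest` and `topk_softmax_sample_ref` are hypothetical names, the "ternary search" is interpreted here as a two-pivot interval narrowing over logit values (the kernel's actual search may differ), and temperature is assumed to divide the logits and to be positive.

```python
import math
import random


def count_ge(xs, t):
    """Number of values in xs that are >= t."""
    return sum(1 for x in xs if x >= t)


def kth_largest(xs, k, iters=32):
    """Find the k-th largest value (1 <= k <= len(xs)) by narrowing [lo, hi).

    Invariant: count_ge(xs, lo) >= k and count_ge(xs, hi) < k,
    so the k-th largest value always lies in [lo, hi).
    """
    lo, hi = min(xs), max(xs)
    if count_ge(xs, hi) >= k:  # the maximum itself is the k-th largest
        return hi
    for _ in range(iters):
        # Two pivots split the interval into three parts per step.
        p1 = lo + (hi - lo) / 3
        p2 = lo + 2 * (hi - lo) / 3
        if count_ge(xs, p2) >= k:
            lo = p2
        elif count_ge(xs, p1) >= k:
            lo, hi = p1, p2
        else:
            hi = p1
    # Few values remain in [lo, hi); finish exactly on that small set.
    above = count_ge(xs, hi)
    candidates = sorted((x for x in xs if lo <= x < hi), reverse=True)
    return candidates[k - above - 1]


def topk_softmax_sample_ref(logits, k, temperature=1.0, rng=None):
    """Sample one index from the temperature-scaled softmax over the top-k logits."""
    if rng is None:
        rng = random.Random(0)
    # Step 1: the k-th largest logit acts as the inclusion threshold.
    threshold = kth_largest(logits, k)
    # Ties at the threshold can admit more than k entries in this sketch.
    top = [(i, x) for i, x in enumerate(logits) if x >= threshold]
    # Step 2: softmax weights, shifted by the max logit for numerical stability.
    m = max(x for _, x in top)
    weights = [math.exp((x - m) / temperature) for _, x in top]
    total = sum(weights)
    # Step 3: inverse-CDF categorical sampling over the unnormalized weights.
    u = rng.random() * total
    acc = 0.0
    for (i, _), w in zip(top, weights):
        acc += w
        if u < acc:
            return i
    return top[-1][0]  # guard against floating-point round-off
```

For example, `topk_softmax_sample_ref([0.1, 5.0, 3.0, -2.0, 4.0], 2)` can only return index 1 or 4, because those hold the two largest logits. A GPU kernel would presumably parallelize the counting in each search step across the threads of a block; the sequential loops here only show the logic.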

Parameters:

  • dtype (DType): The data type of the input logits tensor.
  • out_idx_type (DType): The data type of the output sampled indices.
  • block_size (Int): The number of threads per block (default is 1024).

Args:

  • ctx (DeviceContext): The context for GPU execution.
  • logits (LayoutTensor): Input logits tensor with shape [batch_size, vocab_size].
  • sampled_indices (LayoutTensor): Output buffer for sampled token indices with shape [batch_size].
  • top_k_val (Int): Default number of top elements to sample from for each batch element.
  • temperature_val (Float32): Temperature for softmax scaling (default is 1.0).
  • seed_val (UInt64): Seed for the random number generator (default is 0).
  • top_k_arr (OptionalReg): Optional per-batch top-K values. If provided, overrides top_k_val for each batch element.
  • temperature (OptionalReg): Optional per-batch temperature values. If provided, overrides temperature_val for each batch element.
  • seed (OptionalReg): Optional per-batch seed values. If provided, overrides seed_val for each batch element.
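The override rules for the optional per-batch arguments can be sketched as follows. This Python helper is hypothetical and only illustrates the documented precedence: a provided per-batch array wins, otherwise the scalar default applies. It assumes each array is one-dimensional and indexed by batch row, matching the `Layout.row_major(-1)` layouts in the signature.

```python
from typing import Optional, Sequence, Tuple


def resolve_row_params(
    row: int,
    top_k_val: int,
    temperature_val: float = 1.0,
    seed_val: int = 0,
    top_k_arr: Optional[Sequence[int]] = None,
    temperature: Optional[Sequence[float]] = None,
    seed: Optional[Sequence[int]] = None,
) -> Tuple[int, float, int]:
    """Return the (top_k, temperature, seed) used for one batch row (hypothetical helper)."""
    # Each per-batch array, when present, overrides its scalar default for that row.
    k = top_k_arr[row] if top_k_arr is not None else top_k_val
    t = temperature[row] if temperature is not None else temperature_val
    s = seed[row] if seed is not None else seed_val
    return k, t, s
```

Because each override is resolved independently, a caller can, for instance, pass only `top_k_arr` to vary K per request while sharing one temperature and seed across the whole batch.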
