Mojo function

topk_wrapper

def topk_wrapper[input_type: DType, index_type: DType, *, is_top_p: Bool, block_size: Int, largest: Bool = True, _test_sort: Bool = False](K: Int, num_elements: Int, num_blocks_per_input: Int, in_buffer: UnsafePointer[Scalar[input_type], ImmutUntrackedOrigin], local_topk_vals: UnsafePointer[Scalar[input_type], MutUntrackedOrigin], local_topk_idxs: UnsafePointer[Scalar[index_type], MutUntrackedOrigin], p_threshold: UnsafePointer[Scalar[input_type], MutUntrackedOrigin], skip_sort: UnsafePointer[Scalar[DType.bool], MutUntrackedOrigin])

A copy of _topk_stage1 from Kernels/mojo/nn/topk.mojo, extended with the max_vals and p_threshold arguments, which are used to determine whether sorting is needed for top-p/min-p sampling.
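The following Python sketch is a CPU reference model, not the Mojo kernel. It shows roughly what a stage-1 top-K pass produces. One input row of num_elements values is split into num_blocks_per_input chunks, and each chunk's top-K values and global indices are written block by block. Several details are assumptions for illustration:

- each block handling one contiguous chunk of ceil(num_elements / num_blocks_per_input) elements;
- the output layout [b*K:(b+1)*K];
- the padding of short blocks;
- the helper name topk_stage1_reference.

```python
def topk_stage1_reference(row, K, num_blocks_per_input, largest=True):
    """Hypothetical CPU model of a stage-1 top-K pass over one input row.

    Returns (vals, idxs), each of length num_blocks_per_input * K, where
    entries [b*K:(b+1)*K] hold the top-K of block b (assumed layout).
    """
    num_elements = len(row)
    # Assumption: each block handles a contiguous chunk of ceil(N / blocks).
    chunk = -(-num_elements // num_blocks_per_input)
    # Pad short blocks so every block contributes exactly K slots.
    pad_val = float("-inf") if largest else float("inf")
    vals, idxs = [], []
    for b in range(num_blocks_per_input):
        start = min(b * chunk, num_elements)
        end = min(start + chunk, num_elements)
        # Global indices of this block's elements, best first (stable on ties).
        order = sorted(range(start, end), key=lambda i: row[i], reverse=largest)
        top = order[:K]
        vals.extend(row[i] for i in top)
        idxs.extend(top)
        missing = K - len(top)
        vals.extend([pad_val] * missing)
        idxs.extend([-1] * missing)
    return vals, idxs


vals, idxs = topk_stage1_reference([0.1, 0.5, 0.3, 0.9, 0.2, 0.7], K=2, num_blocks_per_input=2)
print(idxs)  # [1, 2, 3, 5]
```

A later pass would then merge the num_blocks_per_input * K candidates into the final top-K for the row. The "stage 1" in the name suggests this, but it is not confirmed here.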

Arguments:

  • K (Int): Number of top elements to select per block.
  • num_elements (Int): Size of the last dimension of the input buffer (the vocabulary size).
  • num_blocks_per_input (Int): Number of blocks used to process the input data.
  • in_buffer (UnsafePointer[Scalar[input_type]]): Input buffer containing the elements to process.
  • local_topk_vals (UnsafePointer[Scalar[input_type]]): Output buffer that stores the local top-K values.
  • local_topk_idxs (UnsafePointer[Scalar[index_type]]): Output buffer that stores the indices of the local top-K elements.
  • p_threshold (UnsafePointer[Scalar[input_type]]): The top-p threshold if is_top_p is True; otherwise, the min-p coefficient.
  • skip_sort (UnsafePointer[Scalar[DType.bool]]): Output buffer that stores whether sorting is needed.
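This page does not spell out how p_threshold decides whether a sort is needed. The Python sketch below is one plausible rule, offered as an assumption rather than the kernel's actual logic.

For top-p, the rule follows from the definition of nucleus sampling. The nucleus is the smallest set of highest-probability tokens whose cumulative probability reaches p. If the largest probability alone is at least p, the nucleus is just the argmax token, and a full sort is unnecessary.

For min-p, tokens survive if their probability is at least min_p * max(probs). Treating "only one survivor" as the skip condition is a further assumption. The function name needs_sort is hypothetical.

```python
def needs_sort(probs, p_threshold, is_top_p):
    """Hypothetical skip-sort rule; returns True if a full sort is needed."""
    p_max = max(probs)
    if is_top_p:
        # Top-p: if the largest probability alone reaches p, the nucleus is
        # just the argmax token, so no sort is required.
        return p_max < p_threshold
    # Min-p: p_threshold is the min-p coefficient. Tokens survive if their
    # probability is at least min_p * p_max. Assumption: a sort is only
    # needed when more than one token survives the cutoff.
    cutoff = p_threshold * p_max
    survivors = sum(1 for p in probs if p >= cutoff)
    return survivors > 1


probs = [0.6, 0.3, 0.1]
print(needs_sort(probs, 0.5, is_top_p=True))   # False: 0.6 alone covers p = 0.5
print(needs_sort(probs, 0.9, is_top_p=True))   # True
```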

Parameters:

  • input_type (DType): The data type of the elements.
  • index_type (DType): The data type of the output indices.
  • is_top_p (Bool): Whether this is for top-p sampling (True) or min-p sampling (False).
  • block_size (Int): The number of threads per block to use for the kernel.
  • largest (Bool): If True, find the largest values; otherwise, find the smallest.
  • _test_sort (Bool): Internal testing flag; when True, the sort is never skipped.
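This page does not show how block_size and largest interact. A common GPU approach to a block-local top-K, assumed here rather than confirmed for this kernel, is K rounds of a block-wide arg-reduction:

1. Each of the block_size threads scans a strided slice of the chunk and keeps its best candidate.
2. The block reduces the per-thread winners to a single winner.
3. The winner is marked as taken before the next round.

The Python sketch below simulates the threads with a loop, and largest=False flips the comparison. The function name and the lowest-index tie-breaking are illustrative assumptions.

```python
def block_topk_simulated(chunk, K, block_size, largest=True):
    """Simulate K rounds of a block-wide arg-reduction over `chunk`.

    Thread t (0 <= t < block_size) owns elements t, t + block_size, ...
    Returns (values, local_indices) in selection order.
    """
    better = (lambda a, b: a > b) if largest else (lambda a, b: a < b)
    taken = [False] * len(chunk)  # marks elements already selected
    vals, idxs = [], []
    for _ in range(min(K, len(chunk))):
        # Phase 1: each simulated thread finds its best untaken element.
        per_thread = []
        for t in range(block_size):
            best = -1
            for i in range(t, len(chunk), block_size):
                if not taken[i] and (best < 0 or better(chunk[i], chunk[best])):
                    best = i
            if best >= 0:
                per_thread.append(best)
        # Phase 2: block-wide reduction over the per-thread winners.
        win = per_thread[0]
        for i in per_thread[1:]:
            # Ties are broken toward the lower index (an assumption).
            if better(chunk[i], chunk[win]) or (chunk[i] == chunk[win] and i < win):
                win = i
        taken[win] = True  # the next round must skip this element
        vals.append(chunk[win])
        idxs.append(win)
    return vals, idxs


print(block_topk_simulated([4, 9, 1, 7, 9, 3], K=3, block_size=2))
# ([9, 9, 7], [1, 4, 3])
```

Larger block_size values shorten each thread's strided scan in phase 1. The cost is a wider reduction in phase 2, which on a GPU is typically done with warp shuffles and shared memory.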