
Mojo function

block_select_topk

def block_select_topk[T: DType, out_idx_type: DType, largest: Bool = True](scores: UnsafePointer[Scalar[T], MutAnyOrigin], num_blocks: Int, k: Int, out_idxs: UnsafePointer[Scalar[out_idx_type], MutAnyOrigin])

Select the top-k block indices from one row of block scores.

The selection is cooperative across the whole thread block and runs for k_batch = min(k, num_blocks) iterations. On each iteration:

  • every thread finds the best score over its strided slice of scores[0:num_blocks];
  • a block-wide reduction picks the global winner;
  • thread 0 records the winner's index and evicts it by writing the dead sentinel, so it cannot be reselected;
  • a barrier makes the eviction visible before the next iteration.

Output positions [k_batch, k) are filled with the -1 sentinel. The fill starts earlier if the row runs out of selectable values (finite values, when selecting the largest).
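The iteration described above can be modelled sequentially. The following is a minimal Python sketch of the documented semantics, not the Mojo kernel itself: the name `select_topk_reference`, the `num_threads` parameter, the use of -inf / +inf as the dead sentinel, and the tie-breaking order are all assumptions made for illustration.

```python
import math


def select_topk_reference(scores, k, largest=True, num_threads=4):
    """Sequential model of the selection loop (illustrative, not the Mojo kernel)."""
    work = list(scores)  # working copy; evictions overwrite entries here
    num_blocks = len(work)
    # Assumed dead sentinel: -inf when selecting the largest, +inf otherwise.
    dead = -math.inf if largest else math.inf
    better = (lambda a, b: a > b) if largest else (lambda a, b: a < b)

    out = [-1] * k  # slots that are never filled keep the -1 sentinel
    k_batch = min(k, num_blocks)

    for it in range(k_batch):
        # Each simulated thread scans its strided slice: tid, tid + num_threads, ...
        candidates = []
        for tid in range(num_threads):
            best_val, best_idx = dead, -1
            for i in range(tid, num_blocks, num_threads):
                if better(work[i], best_val):
                    best_val, best_idx = work[i], i
            candidates.append((best_val, best_idx))

        # Block-wide reduction: pick the global winner among the candidates.
        win_val, win_idx = dead, -1
        for val, idx in candidates:
            if idx != -1 and better(val, win_val):
                win_val, win_idx = val, idx

        if win_idx == -1:
            continue  # nothing selectable left; this slot stays -1

        out[it] = win_idx      # the step "thread 0" performs in the kernel
        work[win_idx] = dead   # evict the winner so it cannot be reselected
    return out


print(select_topk_reference([0.1, 0.9, 0.5, 0.7], k=6))  # [1, 3, 2, 0, -1, -1]
```

The model keeps iterating even after the row is exhausted, mirroring the requirement that every thread reaches every iteration; exhausted iterations simply leave their output slot at -1.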

Forcing (e.g. always keeping the local block) must already be baked into scores by the caller: write a large sentinel value into that block's score before calling. Selection is purely by value, so a forced block wins a slot.
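As an illustration of caller-side forcing, the sketch below overwrites the local block's score before selection. It is a hedged Python model: `FORCE_SCORE`, `force_blocks`, and `topk_by_value` are hypothetical names, the value 1e30 is an assumed sentinel (finite and representable in float32), and `topk_by_value` is only a stand-in for selection by value with largest = True.

```python
# Hypothetical sentinel: finite, representable in float32, above any real score.
FORCE_SCORE = 1e30


def force_blocks(scores, forced_idxs, force_score=FORCE_SCORE):
    """Return a copy of scores with the forced blocks set to the sentinel."""
    out = list(scores)
    for i in forced_idxs:
        out[i] = force_score
    return out


def topk_by_value(scores, k):
    """Stand-in for selection purely by value (largest first)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]


row = [0.3, 0.9, 0.1, 0.6]
local_block = 2  # lowest raw score, so it would normally be dropped

unforced = topk_by_value(row, 2)                            # [1, 3]
forced = topk_by_value(force_blocks(row, [local_block]), 2)  # [2, 1]
```

Because the selection never inspects anything but the score, the forced block takes one of the k slots and displaces the weakest block that would otherwise have been chosen.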

Note: Each call is executed by one thread block (grid_dim = one block per row). block_dim.x must be a multiple of the warp size (required by _block_reduce_topk), and all threads must reach every iteration uniformly. That uniformity is guaranteed here because k_batch depends only on k and num_blocks, which are identical across the block.
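A host-side check of the block_dim.x constraint might look like the following Python sketch. The helper names are hypothetical, and the default warp size of 32 is an assumption; the actual value depends on the target GPU and should be queried from the device.

```python
def round_up_to_warp(n_threads, warp_size=32):
    """Round a desired thread count up to the next multiple of warp_size."""
    if n_threads <= 0 or warp_size <= 0:
        raise ValueError("n_threads and warp_size must be positive")
    return ((n_threads + warp_size - 1) // warp_size) * warp_size


def is_valid_block_dim(block_dim_x, warp_size=32):
    """True when block_dim_x is a positive multiple of warp_size."""
    return block_dim_x > 0 and block_dim_x % warp_size == 0


block_dim_x = round_up_to_warp(100)  # 128 with the assumed warp size of 32
```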

Parameters:

  • T (DType): Element dtype of the scores (float32 expected for MSA).
  • out_idx_type (DType): Output index dtype (int32 for MSA).
  • largest (Bool): Select the largest (True) or smallest (False) values. Defaults to True.

Args: