Mojo function
block_select_topk
def block_select_topk[T: DType, out_idx_type: DType, largest: Bool = True](scores: UnsafePointer[Scalar[T], MutAnyOrigin], num_blocks: Int, k: Int, out_idxs: UnsafePointer[Scalar[out_idx_type], MutAnyOrigin])
Select the top-k block indices from one row of block scores.
Cooperative across the whole thread block. Each of the
k_batch = min(k, num_blocks) iterations runs these steps:
- Every thread finds the best score over its strided slice of scores[0:num_blocks].
- A block-wide reduction picks the global winner.
- Thread 0 records the winner's index and evicts it (writes the dead sentinel so it cannot be reselected).
- A barrier makes the eviction visible before the next iteration.
Output positions [k_batch, k) -- or earlier, if the row runs out of
selectable (finite, for largest) values -- are filled with the -1 sentinel.
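The iteration described above can be modeled on the host as a sequential sketch. This is a Python reference model for reasoning about the output, not the Mojo kernel. The name block_select_topk_model and the num_threads parameter are hypothetical. It also assumes that the dead sentinel is -inf for largest (+inf for smallest), that any value strictly better than the dead sentinel is selectable, and that ties go to the lowest index; none of these details are stated in the description.

```python
import math


def block_select_topk_model(scores, num_blocks, k, largest=True, num_threads=4):
    """Sequential CPU model of the cooperative selection; mutates `scores`."""
    # Assumption: the "dead" sentinel is -inf (largest) or +inf (smallest).
    dead = -math.inf if largest else math.inf
    better = (lambda a, b: a > b) if largest else (lambda a, b: a < b)

    out_idxs = [-1] * k              # slots that are never filled keep -1
    k_batch = min(k, num_blocks)     # depends only on k and num_blocks

    for it in range(k_batch):
        # Step 1: every simulated thread scans its strided slice.
        candidates = []
        for tid in range(num_threads):
            best_val, best_idx = dead, -1
            for i in range(tid, num_blocks, num_threads):
                if better(scores[i], best_val):
                    best_val, best_idx = scores[i], i
            candidates.append((best_val, best_idx))

        # Step 2: stand-in for the block-wide reduction.
        # Assumption: ties are broken toward the lowest index.
        win_val, win_idx = dead, -1
        for val, idx in candidates:
            if idx < 0:
                continue             # this thread found nothing selectable
            if win_idx < 0 or better(val, win_val) or (val == win_val and idx < win_idx):
                win_val, win_idx = val, idx

        # Step 3: "thread 0" records the winner and evicts it.
        # If nothing was selectable, the slot keeps its -1 sentinel.
        if win_idx >= 0:
            out_idxs[it] = win_idx
            scores[win_idx] = dead
        # Step 4: the barrier is implicit in this sequential model.

    return out_idxs


row = [0.5, 2.0, -1.0, 2.0, 1.5]
top3 = block_select_topk_model(row, num_blocks=5, k=3)
```

The loop always runs k_batch times, even when the row is exhausted, mirroring the requirement that all threads reach every iteration uniformly. After the call, row holds the dead sentinel at every selected index, which is why the caller must treat scores as scratch.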
Forced selection (e.g. always keeping the local block) must already be baked
into scores by the caller: write a large sentinel value into that block's
score before calling. Selection is purely by value, so a forced block wins a slot.
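A minimal host-side sketch of the forcing convention follows. The helpers force_block and topk_by_value are hypothetical: topk_by_value is a plain sort standing in for value-based selection, not the kernel. Using the largest finite float32 as the forcing value is also an assumption; the description only calls for "a large sentinel value".

```python
# Assumption: the largest finite float32 serves as the forcing value.
FLOAT32_MAX = 3.4028234663852886e38


def force_block(scores, block_idx, largest=True):
    """Overwrite one block's score so value-based selection must pick it."""
    scores[block_idx] = FLOAT32_MAX if largest else -FLOAT32_MAX


def topk_by_value(scores, k, largest=True):
    """Hypothetical stand-in for the selection: sort indices by score."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=largest)
    return order[:k]


scores = [0.9, 0.1, 0.7, 0.4]
unforced = topk_by_value(scores, k=2)   # selection on the raw scores

local_block = 3                         # block the caller always wants kept
force_block(scores, local_block)
forced = topk_by_value(scores, k=2)     # the forced block now takes a slot
```

Because the forced block consumes one of the k slots, only k - 1 slots remain for blocks chosen on merit.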
Note:
One thread block per call (grid_dim = one block per row). block_dim.x
must be a multiple of the warp size (required by _block_reduce_topk),
and all threads must reach every iteration uniformly -- guaranteed here
because k_batch depends only on k and num_blocks, identical across
the block.
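One way to satisfy the warp-multiple requirement is to round the requested thread count up on the host before launching. The helper round_up_to_warp is hypothetical, and the default warp size of 32 is an assumption (some GPUs use 64-wide wavefronts); query the target device rather than hard-coding it.

```python
def round_up_to_warp(num_threads, warp_size=32):
    """Round a thread count up to the next multiple of the warp size."""
    if num_threads <= 0 or warp_size <= 0:
        raise ValueError("num_threads and warp_size must be positive")
    return ((num_threads + warp_size - 1) // warp_size) * warp_size


block_dim_x = round_up_to_warp(100)   # candidate value for block_dim.x
```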
Parameters:
- T (DType): Element dtype of the scores (float32 expected for MSA).
- out_idx_type (DType): Output index dtype (int32 for MSA).
- largest (Bool): Select largest (True) or smallest values.
Args:
- scores (UnsafePointer[Scalar[T], MutAnyOrigin]): Pointer to this row's block scores, length num_blocks. Mutated in place during extraction (the caller must treat it as scratch).
- num_blocks (Int): Number of valid block scores in the row.
- k (Int): Number of indices to emit (output length).
- out_idxs (UnsafePointer[Scalar[out_idx_type], MutAnyOrigin]): Pointer to this row's output indices, length k.