
Mojo function

persistent_topk_block

def persistent_topk_block(ctx: DeviceContext, in_scores: UnsafePointer[Float32, ImmutAnyOrigin], out_idxs: UnsafePointer[Int32, MutAnyOrigin], N: Int, K: Int, total_seq_len: Int)

Launch a block-wide bitonic sort top-k over total_seq_len rows of scores.

Handles N ≤ PERSISTENT_TOPK_MAX_N (= 2048). The call site must check this before calling; for larger N, fall back to topk_gpu.

Each row of N float32 scores is sorted descending in a single block, writing the K highest-scoring indices (as int32) to out_idxs.
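
To make the per-row contract concrete, here is a minimal CPU reference sketch in Python (not the Mojo kernel itself). It runs a bitonic sorting network over each row sequentially and keeps the first K indices. Several details are assumptions, not documented behavior: the helper name `topk_rows_reference` is hypothetical, rows are assumed to be stored contiguously (row-major), `out_idxs` is assumed to hold K entries per row, short rows are assumed to be padded to a power of two with `-inf`, and ties are broken toward the lower index.

```python
import math

PERSISTENT_TOPK_MAX_N = 2048  # limit stated in the documentation


def topk_rows_reference(scores, N, K, total_seq_len):
    """Hypothetical CPU reference: top-K indices per row, best score first.

    `scores` is a flat, row-major sequence of total_seq_len * N floats
    (layout is an assumption). Returns a flat list of total_seq_len * K ints.
    """
    if not (0 < K <= N <= PERSISTENT_TOPK_MAX_N):
        raise ValueError("need 0 < K <= N <= PERSISTENT_TOPK_MAX_N")
    if len(scores) != total_seq_len * N:
        raise ValueError("scores must hold total_seq_len * N values")

    # A bitonic network needs a power-of-two length; pad with -inf
    # sentinels (assumed padding strategy).
    size = 1
    while size < N:
        size *= 2

    out_idxs = []
    for row in range(total_seq_len):
        vals = list(scores[row * N:(row + 1) * N]) + [-math.inf] * (size - N)
        idxs = list(range(size))

        k = 2
        while k <= size:          # length of the bitonic runs being merged
            j = k // 2
            while j >= 1:         # compare-exchange distance
                for i in range(size):
                    partner = i ^ j
                    if partner <= i:
                        continue
                    # Key: higher score first, then lower index (assumed tie-break).
                    a = (vals[i], -idxs[i])
                    b = (vals[partner], -idxs[partner])
                    descending = (i & k) == 0
                    out_of_order = a < b if descending else a > b
                    if out_of_order:
                        vals[i], vals[partner] = vals[partner], vals[i]
                        idxs[i], idxs[partner] = idxs[partner], idxs[i]
                j //= 2
            k *= 2

        out_idxs.extend(idxs[:K])
    return out_idxs


example_scores = [0.1, 0.9, 0.5, 0.7, 0.3,    # row 0
                  2.0, -1.0, 4.0, 3.0, 0.0]   # row 1
example_result = topk_rows_reference(example_scores, N=5, K=2, total_seq_len=2)
print(example_result)  # [1, 3, 2, 3]
```

A sketch like this can serve as an oracle when checking GPU output on small inputs, provided the test data avoids ties or the kernel's actual tie-breaking is known.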

The shared-memory (SMEM) budget is 2 * PERSISTENT_TOPK_MAX_N * 4 bytes = 16 KB per block.
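
The budget arithmetic can be checked with a few lines of Python. Reading the factor 2 as two per-block buffers (for example, one for scores and one for indices) and the factor 4 as the byte size of a float32 or int32 element is an assumption here. The text above states only the formula and the result.

```python
PERSISTENT_TOPK_MAX_N = 2048   # limit stated in the documentation
NUM_BUFFERS = 2                # assumed meaning of the factor 2
BYTES_PER_ELEMENT = 4          # float32 and int32 are both 4 bytes wide

smem_bytes = NUM_BUFFERS * PERSISTENT_TOPK_MAX_N * BYTES_PER_ELEMENT
smem_kib = smem_bytes // 1024

print(smem_bytes, smem_kib)  # 16384 16
```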

Args: