Mojo function
persistent_topk_block
def persistent_topk_block(ctx: DeviceContext, in_scores: UnsafePointer[Float32, ImmutAnyOrigin], out_idxs: UnsafePointer[Int32, MutAnyOrigin], N: Int, K: Int, total_seq_len: Int)
Launch block-wide bitonic sort top-k for total_seq_len score rows.
Handles N ≤ PERSISTENT_TOPK_MAX_N (= 2048). The call site must check this bound
before calling; for larger N, fall back to topk_gpu.
Each row of N float32 scores is sorted in descending order within a single block,
and the indices of the K highest scores are written (as int32) to out_idxs.
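To make the per-row behavior concrete, here is a minimal CPU sketch in Python of a bitonic-sort top-k over one row. It is not the kernel's implementation: the function name bitonic_topk_row, the padding to a power of two, and the tie-breaking rule (lower index first) are all assumptions made for illustration, and NaN scores are not handled.

```python
import math

def bitonic_topk_row(scores, k):
    """Return indices of the k largest scores, highest score first."""
    n = len(scores)
    if not 0 <= k <= n:
        raise ValueError("k must satisfy 0 <= k <= len(scores)")

    # Assumption: pad to the next power of two so the network is regular.
    size = 1
    while size < n:
        size *= 2

    # Sort keys ascending by (-score, index): highest score first, with ties
    # resolved toward the lower index (an assumption, not documented behavior).
    keys = [(-s, i) for i, s in enumerate(scores)]
    keys += [(math.inf, i) for i in range(n, size)]  # padding sorts last

    span = 2
    while span <= size:
        stride = span // 2
        while stride >= 1:
            # Every compare-exchange in this pass touches a disjoint pair,
            # which is what lets a GPU block run one pass across its threads.
            for i in range(size):
                partner = i ^ stride
                if partner > i:
                    ascending = (i & span) == 0
                    out_of_order = (
                        keys[i] > keys[partner] if ascending
                        else keys[i] < keys[partner]
                    )
                    if out_of_order:
                        keys[i], keys[partner] = keys[partner], keys[i]
            stride //= 2
        span *= 2

    return [idx for _, idx in keys[:k]]

print(bitonic_topk_row([0.1, 0.9, 0.5, 0.7, 0.3], 3))  # [1, 3, 2]
```

The inner loop over i is sequential here; in a block-wide version each thread would presumably handle its own compare-exchange pairs, with a barrier between passes.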
The SMEM budget is 2 * PERSISTENT_TOPK_MAX_N * 4 = 16 KB per block.
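The budget arithmetic can be checked in a few lines of Python. The reading of the factor 2 as two shared arrays (one for scores, one for indices) is an assumption; the page states only the formula and the 4-byte element size implied by float32 and int32.

```python
PERSISTENT_TOPK_MAX_N = 2048  # bound stated on this page
BYTES_PER_ELEMENT = 4         # float32 and int32 are both 4 bytes wide
NUM_ARRAYS = 2                # assumption: one score array plus one index array

smem_bytes = NUM_ARRAYS * PERSISTENT_TOPK_MAX_N * BYTES_PER_ELEMENT
smem_kib = smem_bytes / 1024

print(smem_bytes, smem_kib)  # 16384 16.0
```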
Args:
- ctx (DeviceContext): Device context.
- in_scores (UnsafePointer[Float32, ImmutAnyOrigin]): Flat score buffer [total_seq_len × N], row-major.
- out_idxs (UnsafePointer[Int32, MutAnyOrigin]): Output buffer [total_seq_len × K], row-major (int32).
- N (Int): Score columns per token (≤ PERSISTENT_TOPK_MAX_N).
- K (Int): Top-k count per token (≤ N).
- total_seq_len (Int): Number of rows (one block per row).
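The flat row-major layout of the two buffers, together with the bounds on N and K, can be sketched as a CPU reference in Python. This is an illustration only: topk_rows_reference is a hypothetical name, it uses Python's built-in sorted rather than a bitonic network, and breaking ties toward the lower column index is an assumption.

```python
PERSISTENT_TOPK_MAX_N = 2048  # bound stated on this page

def topk_rows_reference(in_scores, n, k, total_seq_len):
    """CPU model: return a flat [total_seq_len x k] list of column indices."""
    if n > PERSISTENT_TOPK_MAX_N:
        # The real call site would fall back to topk_gpu here.
        raise ValueError("n exceeds PERSISTENT_TOPK_MAX_N")
    if not 0 <= k <= n:
        raise ValueError("k must satisfy 0 <= k <= n")
    if len(in_scores) != total_seq_len * n:
        raise ValueError("in_scores must hold total_seq_len * n values")

    out_idxs = [0] * (total_seq_len * k)
    for row in range(total_seq_len):  # the GPU launch uses one block per row
        base = row * n                # row-major: element (row, col) is at row * n + col
        order = sorted(range(n), key=lambda col: (-in_scores[base + col], col))
        for j in range(k):
            out_idxs[row * k + j] = order[j]
    return out_idxs

# Two rows of three scores each, top-2 per row.
scores = [0.1, 0.9, 0.5,
          0.8, 0.2, 0.3]
print(topk_rows_reference(scores, n=3, k=2, total_seq_len=2))  # [1, 2, 0, 2]
```

Note that the written indices are column positions within each row (0 to N - 1), not offsets into the flat in_scores buffer.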