
Mojo function

fill_invalid_topk_kernel

def fill_invalid_topk_kernel[IROLayoutType: TensorLayout, iro_origin: ImmutOrigin, cache_lengths_layout: TensorLayout, use_causal_mask: Bool](output_indices: UnsafePointer[Int32, MutAnyOrigin], input_row_offsets: TileTensor[DType.uint32, IROLayoutType, iro_origin], cache_lengths: TileTensor[DType.uint32, cache_lengths_layout, ImmutAnyOrigin], total_seq_len: Int, top_k: Int, effective_k: Int)

Fill invalid positions with -1 in topk output.

topk_gpu has already written valid indices to positions [0, effective_k) of each row in output_indices (rows are laid out with a stride of top_k). This kernel overwrites with -1 every position that does not hold a valid key index:

  • Positions [effective_k, top_k) when top_k > max_num_keys
  • Positions where k_idx >= num_keys for that token
  • Positions where the index VALUE >= num_keys (topk selected an invalid key)

Output shape: [total_seq_len, top_k].
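The three fill rules can be illustrated on a single output row. The following is a minimal plain-Python sketch of the described behavior, not the Mojo kernel itself. The helper name `fill_invalid_row` is hypothetical, and `num_keys` is assumed to be already known for the token.

```python
def fill_invalid_row(row, effective_k, num_keys):
    """Return a copy of one top_k-wide output row with invalid slots set to -1.

    row:         list of ints of length top_k; slots [0, effective_k) hold
                 the indices written by the top-k step.
    effective_k: number of slots the top-k step actually wrote.
    num_keys:    number of keys visible to this token (assumed given).
    """
    out = list(row)
    for k_idx in range(len(out)):
        if k_idx >= effective_k:
            # Rule 1: padding slots beyond what top-k wrote.
            out[k_idx] = -1
        elif k_idx >= num_keys:
            # Rule 2: more slots than this token has visible keys.
            out[k_idx] = -1
        elif out[k_idx] >= num_keys:
            # Rule 3: top-k selected a key this token cannot see.
            out[k_idx] = -1
    return out
```

For example, with top_k = 6, effective_k = 4, and num_keys = 3, the row `[2, 0, 5, 1, 99, 99]` becomes `[2, 0, -1, -1, -1, -1]`: slot 2 holds an out-of-range value, slot 3 exceeds the number of visible keys, and slots 4 and 5 are padding.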

With causal masking, each token can only see keys up to its position: num_keys = cache_len + local_seq_idx + 1.

Without causal masking, each token can see all keys in the batch: num_keys = cache_len + seq_len.
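The two num_keys formulas can be checked with a small plain-Python sketch. It assumes that `input_row_offsets` holds batch_size + 1 ragged offsets, so that sequence `b` owns tokens `[input_row_offsets[b], input_row_offsets[b + 1])`. This layout is an assumption, not something the page states. The helper name `num_keys_per_token` is hypothetical.

```python
def num_keys_per_token(input_row_offsets, cache_lengths, use_causal_mask):
    """Compute num_keys for every token, in flattened token order.

    Assumes input_row_offsets has len(cache_lengths) + 1 entries (ragged
    offsets); this layout is an assumption made for illustration.
    """
    counts = []
    for b, cache_len in enumerate(cache_lengths):
        start = input_row_offsets[b]
        end = input_row_offsets[b + 1]
        seq_len = end - start
        for local_seq_idx in range(seq_len):
            if use_causal_mask:
                # Token sees the cache plus keys up to and including itself.
                counts.append(cache_len + local_seq_idx + 1)
            else:
                # Token sees the cache plus every key of its sequence.
                counts.append(cache_len + seq_len)
    return counts
```

With offsets `[0, 2, 5]` and cache lengths `[3, 0]`, the causal case gives `[4, 5, 1, 2, 3]` and the non-causal case gives `[5, 5, 3, 3, 3]`.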