Mojo function

fill_invalid_topk_kernel

fill_invalid_topk_kernel[IROLayoutType: TensorLayout, iro_origin: ImmutOrigin, cache_lengths_layout: Layout, use_causal_mask: Bool](output_indices: UnsafePointer[Int32, MutAnyOrigin], input_row_offsets: TileTensor[DType.uint32, IROLayoutType, iro_origin], cache_lengths: LayoutTensor[DType.uint32, cache_lengths_layout, ImmutAnyOrigin], total_seq_len: Int, top_k: Int, effective_k: Int)

Fill invalid positions with -1 in topk output.

topk_gpu has already written valid indices to positions [0, effective_k) of each row in output_indices (rows have a stride of top_k). This kernel writes -1 to the positions that are invalid:

  • Positions [effective_k, top_k) when top_k > max_num_keys
  • Positions where k_idx >= num_keys for that token
  • Positions where the index VALUE >= num_keys (topk selected an invalid key)
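
The three conditions above can be sketched as a small CPU reference for a single output row. This is an illustrative Python model only, not the Mojo kernel: the helper name fill_invalid_row is hypothetical, and num_keys is taken as an input for the token that owns the row.

```python
def fill_invalid_row(row, num_keys, effective_k):
    """Return a copy of one top_k-wide output row with invalid slots set to -1.

    Hypothetical reference helper; `row` holds the indices written by the
    top-k step, `num_keys` is the number of keys visible to this token.
    """
    out = list(row)
    for k_idx, value in enumerate(out):
        if (
            k_idx >= effective_k   # slot in [effective_k, top_k), never written
            or k_idx >= num_keys   # more slots than keys visible to this token
            or value >= num_keys   # selected index points past the visible keys
        ):
            out[k_idx] = -1
    return out


# top_k = 6, effective_k = 4; 99 stands in for unwritten data.
row = [3, 0, 5, 1, 99, 99]
result = fill_invalid_row(row, num_keys=4, effective_k=4)
print(result)  # [3, 0, -1, 1, -1, -1]
```

In the example, slot 2 is cleared because its value 5 is not a valid key index for a token with 4 visible keys, and slots 4 and 5 are cleared because they lie past effective_k.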

Output shape: [total_seq_len, top_k].

With causal masking, each token can only see keys up to its position: num_keys = cache_len + local_seq_idx + 1

Without causal masking, each token can see all keys in the batch: num_keys = cache_len + seq_len
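
As a worked illustration of the two num_keys formulas, the following Python sketch computes the value for every token in a ragged batch. It assumes, beyond what this page states, that input_row_offsets holds batch_size + 1 cumulative offsets so that sequence b spans tokens [input_row_offsets[b], input_row_offsets[b + 1]). The function name num_keys_per_token is hypothetical.

```python
def num_keys_per_token(input_row_offsets, cache_lengths, use_causal_mask):
    """Return num_keys for each token, in flattened token order.

    Assumption: input_row_offsets has len(cache_lengths) + 1 entries and
    gives the start offset of each sequence in the flattened token axis.
    """
    counts = []
    for b, cache_len in enumerate(cache_lengths):
        start, end = input_row_offsets[b], input_row_offsets[b + 1]
        seq_len = end - start
        for local_seq_idx in range(seq_len):
            if use_causal_mask:
                # Token sees the cache plus new tokens up to and including itself.
                counts.append(cache_len + local_seq_idx + 1)
            else:
                # Token sees the cache plus every new token of its sequence.
                counts.append(cache_len + seq_len)
    return counts


offsets = [0, 3, 5]      # two sequences with 3 and 2 new tokens
cache_lengths = [2, 0]   # cached keys per sequence
causal = num_keys_per_token(offsets, cache_lengths, use_causal_mask=True)
full = num_keys_per_token(offsets, cache_lengths, use_causal_mask=False)
print(causal)  # [3, 4, 5, 1, 2]
print(full)    # [5, 5, 5, 2, 2]
```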