Mojo function
fill_invalid_topk_kernel
fill_invalid_topk_kernel[IROLayoutType: TensorLayout, iro_origin: ImmutOrigin, cache_lengths_layout: Layout, use_causal_mask: Bool](output_indices: UnsafePointer[Int32, MutAnyOrigin], input_row_offsets: TileTensor[DType.uint32, IROLayoutType, iro_origin], cache_lengths: LayoutTensor[DType.uint32, cache_lengths_layout, ImmutAnyOrigin], total_seq_len: Int, top_k: Int, effective_k: Int)
Fill invalid positions with -1 in topk output.
topk_gpu has already written valid indices to positions [0, effective_k) of each row of output_indices, whose rows have a stride of top_k. This kernel then sets the following positions to -1:
- Positions [effective_k, top_k) when top_k > max_num_keys
- Positions where k_idx >= num_keys for that token
- Positions where the index VALUE >= num_keys (topk selected an invalid key)
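The three conditions above can be sketched as a CPU-side reference for a single output row. This is a plain Python illustration, not the kernel's code: the helper name `fill_invalid_row` is hypothetical, and it assumes `num_keys` for the token has already been computed.

```python
def fill_invalid_row(row, effective_k, num_keys):
    """Return a copy of one top_k-wide output row with invalid slots set to -1.

    row         -- the top_k entries for one token (len(row) == top_k)
    effective_k -- number of slots topk_gpu actually wrote
    num_keys    -- number of keys visible to this token
    (Hypothetical helper; the parameter names mirror the kernel's terms.)
    """
    out = list(row)
    for k_idx in range(len(out)):
        beyond_written = k_idx >= effective_k   # slot never written by topk_gpu
        beyond_keys = k_idx >= num_keys         # more slots than visible keys
        bad_value = out[k_idx] >= num_keys      # selected index is not a visible key
        if beyond_written or beyond_keys or bad_value:
            out[k_idx] = -1
    return out


# top_k = 6, effective_k = 4, and the token can see 4 keys.
print(fill_invalid_row([3, 0, 5, 1, 99, 99], effective_k=4, num_keys=4))
# [3, 0, -1, 1, -1, -1]
```

Slot 2 is cleared because its value (5) is not a visible key, and slots 4 and 5 are cleared because they lie past effective_k.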
Output shape: [total_seq_len, top_k].
With causal masking, each token can only see keys up to its position:
num_keys = cache_len + local_seq_idx + 1
Without causal masking, each token can see all keys in the batch:
num_keys = cache_len + seq_len
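As a rough illustration of the two formulas, the sketch below computes num_keys for every token of a ragged batch in plain Python. It assumes that input_row_offsets holds batch_size + 1 offsets, so that sequence b covers tokens [input_row_offsets[b], input_row_offsets[b + 1]); that layout and the helper name `num_keys_per_token` are assumptions made for the example, not something this page states.

```python
def num_keys_per_token(input_row_offsets, cache_lengths, use_causal_mask):
    """Compute num_keys for each token, in token order (hypothetical helper).

    Assumes input_row_offsets has len(cache_lengths) + 1 entries and that
    sequence b spans tokens [input_row_offsets[b], input_row_offsets[b + 1]).
    """
    result = []
    for b, cache_len in enumerate(cache_lengths):
        seq_len = input_row_offsets[b + 1] - input_row_offsets[b]
        for local_seq_idx in range(seq_len):
            if use_causal_mask:
                # Keys up to and including the token's own position.
                result.append(cache_len + local_seq_idx + 1)
            else:
                # All cached keys plus all new tokens of the sequence.
                result.append(cache_len + seq_len)
    return result


offsets = [0, 3, 5]      # two sequences with 3 and 2 new tokens
cache_lengths = [2, 0]   # 2 cached keys for the first, none for the second

print(num_keys_per_token(offsets, cache_lengths, use_causal_mask=True))
# [3, 4, 5, 1, 2]
print(num_keys_per_token(offsets, cache_lengths, use_causal_mask=False))
# [5, 5, 5, 2, 2]
```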