Mojo struct
Struct_msa_indexer_ragged_paged
struct Struct_msa_indexer_ragged_paged
Implemented traits
Methods
execute
static def execute[*, num_index_heads: Int, idx_head_dim: Int, block_size: Int, topk: Int, init_blocks: Int, local_blocks: Int](out_idxs: ManagedTensorSlice[Output, static_spec=out_idxs.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], input_row_offsets: ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec], prefix_lens: ManagedTensorSlice[Input, static_spec=prefix_lens.static_spec], k_blocks: ManagedTensorSlice[MutableInput, static_spec=k_blocks.static_spec], k_cache_lengths: ManagedTensorSlice[Input, static_spec=k_cache_lengths.static_spec], k_lookup_table: ManagedTensorSlice[Input, static_spec=k_lookup_table.static_spec], k_max_lengths: ManagedTensorSlice[Input, static_spec=k_max_lengths.static_spec], layer_idx: UInt32, scale: Float32, ctx: DeviceContext)
Select top-k key blocks per (index head, query token).
Dispatches to the decode kernel when kv_collection.max_seq_length == 1
(one new index-K token per sequence) and to the prefill kernel
otherwise.
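As a rough illustration of the dispatch rule above, the following Python sketch picks a path and the resulting number of output rows. `plan_dispatch` is a hypothetical helper, not part of the Mojo API; it assumes `max_seq_length` is the value the description calls `kv_collection.max_seq_length`, and it takes the row counts from the `out_idxs` description (batch on decode, total_q on prefill).

```python
def plan_dispatch(input_row_offsets, max_seq_length):
    """Return (path, num_rows) for a ragged batch.

    Hypothetical helper: it only mirrors the documented rule and
    does not launch any kernel.
    """
    batch = len(input_row_offsets) - 1
    if max_seq_length == 1:
        # Decode: one new index-K token per sequence, one output row each.
        return "decode", batch
    # Prefill: one output row per query token across the whole batch.
    total_q = input_row_offsets[-1]
    return "prefill", total_q


print(plan_dispatch([0, 1, 2, 3], 1))
print(plan_dispatch([0, 3, 4, 9], 5))
```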
Parameters:
- num_index_heads (Int): Number of index (query) heads.
- idx_head_dim (Int): Index head dimension.
- block_size (Int): KV block size in tokens (== page_size).
- topk (Int): Number of blocks to select per token.
- init_blocks (Int): Always-keep leading blocks (forced high score).
- local_blocks (Int): Always-keep trailing/local blocks (forced score).
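To make the interaction between topk, init_blocks, and local_blocks concrete, here is a minimal Python reference sketch for a single (index head, query token) pair. It is not the kernel's implementation. It assumes the per-block scores are already computed, that "forced" means the score is raised to +infinity, and that ties are broken by the lower block index; `select_topk_blocks` is a hypothetical name. The -1 padding follows the `out_idxs` description.

```python
import math


def select_topk_blocks(scores, topk, init_blocks, local_blocks):
    """Pick up to `topk` block indices, always favoring edge blocks.

    Assumption: forcing a block means giving it a +inf score so it
    outranks every ordinary block.
    """
    n = len(scores)
    forced = list(scores)
    # Leading blocks are always kept.
    for i in range(min(init_blocks, n)):
        forced[i] = math.inf
    # Trailing (local) blocks are always kept.
    for i in range(max(n - local_blocks, 0), n):
        forced[i] = math.inf
    # Highest score first; ties go to the lower block index (assumed).
    order = sorted(range(n), key=lambda i: (-forced[i], i))
    picked = order[:topk]
    # Pad with -1 when fewer than `topk` blocks exist.
    return picked + [-1] * (topk - len(picked))


print(select_topk_blocks([0.1, 0.9, 0.3, 0.8, 0.2, 0.05], 4, 1, 1))
print(select_topk_blocks([0.5, 0.2], 4, 1, 1))
```

In this sketch, if topk is smaller than init_blocks + local_blocks, some forced blocks are dropped; how the real kernel handles that case is not stated on this page.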
Args:
- out_idxs (ManagedTensorSlice[Output, static_spec=out_idxs.static_spec]): Output block indices [num_index_heads, num_rows, topk], int32, -1-padded (num_rows == total_q on prefill, batch on decode).
- q (ManagedTensorSlice[Input, static_spec=q.static_spec]): Query tensor [num_rows, num_index_heads, idx_head_dim] BF16.
- input_row_offsets (ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec]): Ragged query offsets [batch + 1] uint32 (used on the prefill path; on decode it is [0, 1, ..., batch]).
- prefix_lens (ManagedTensorSlice[Input, static_spec=prefix_lens.static_spec]): Per-batch cached-key count [batch] uint32 (pass the index-K cache_lengths); used as the decode seq_lens.
- k_blocks (ManagedTensorSlice[MutableInput, static_spec=k_blocks.static_spec]): Index-K paged blocks [num_blocks, 1, num_layers, page_size, 1, idx_head_dim] BF16.
- k_cache_lengths (ManagedTensorSlice[Input, static_spec=k_cache_lengths.static_spec]): Index-K cache lengths [batch] uint32.
- k_lookup_table (ManagedTensorSlice[Input, static_spec=k_lookup_table.static_spec]): Index-K page table [batch, max_pages] uint32.
- k_max_lengths (ManagedTensorSlice[Input, static_spec=k_max_lengths.static_spec]): Index-K max lengths [1, 2] uint32.
- layer_idx (UInt32): Layer index for the index-K cache.
- scale (Float32): QK scale.
- ctx (DeviceContext): Device context.
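The ragged offsets and the page table can be read as follows. This is a Python sketch under stated assumptions, not the kernel's code: it assumes each k_lookup_table entry is the physical block index (first dimension of k_blocks) for a sequence's logical page, which is the usual paged-cache convention. `rows_to_batch` and `locate_token` are hypothetical helpers.

```python
def rows_to_batch(input_row_offsets):
    """Expand ragged offsets [batch + 1] into a per-row sequence index."""
    out = []
    for b in range(len(input_row_offsets) - 1):
        count = input_row_offsets[b + 1] - input_row_offsets[b]
        out.extend([b] * count)
    return out


def locate_token(k_lookup_table, batch_idx, token_pos, page_size):
    """Map a token position to (physical_block, offset_in_page).

    Assumption: the page table stores physical block indices.
    """
    logical_page, offset = divmod(token_pos, page_size)
    return k_lookup_table[batch_idx][logical_page], offset


# Three sequences with 3, 1, and 2 query tokens.
print(rows_to_batch([0, 3, 4, 6]))
# Token 5 of sequence 1 with page_size 4 sits in that sequence's second page.
print(locate_token([[7, 2], [5, 9]], 1, 5, 4))
```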