Mojo function
mla_indexer_ragged_float8_paged
mla_indexer_ragged_float8_paged[dtype: DType, KCollectionT: KVCollectionT, num_heads: Int, depth: Int, top_k: Int, mask_str: StringSlice[StaticConstantOrigin]](output_indices: TileTensor[DType.int32, address_space=output_indices.address_space, linear_idx_type=output_indices.linear_idx_type, element_size=output_indices.element_size], q: TileTensor[dtype, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], q_s: TileTensor[DType.float32, address_space=q_s.address_space, linear_idx_type=q_s.linear_idx_type, element_size=q_s.element_size], input_row_offsets: TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], k_collection: KCollectionT, layer_idx: UInt32, ctx: DeviceContext)
Compute FP8 indexed attention scores using a paged KV cache and return the top-k indices.
This function:
- Computes FP8 matmul between q and cached k (with scales), aggregated across heads
- Applies the specified mask (causal, etc.)
- Computes top-k indices per token (scores are summed across all heads)
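The three steps above can be sketched as a plain-Python reference for a single sequence. This is an illustrative model only, not the kernel's implementation. It uses ordinary floats in place of FP8 and makes several assumptions the page does not confirm: one key vector per cache position shared by all heads, one scale per key, a causal mask in which the query tokens are the last entries of the key sequence, and output ordered by descending score with ties broken by lower index. The name `indexer_reference` is hypothetical.

```python
def indexer_reference(q, q_s, k, k_s, top_k, causal=True):
    """Float reference for one sequence (illustrative, not the real kernel).

    q:   [num_tokens][num_heads][depth] query values
    q_s: [num_tokens][num_heads] query scales
    k:   [num_keys][depth] cached key values (assumed shared by all heads)
    k_s: [num_keys] key scales (assumed one scale per key)
    """
    num_tokens = len(q)
    num_keys = len(k)
    # Assumption: the query tokens are the last num_tokens cache positions,
    # so query t may attend to keys 0..start + t under a causal mask.
    start = num_keys - num_tokens
    out = []
    for t in range(num_tokens):
        scored = []
        for j in range(num_keys):
            if causal and j > start + t:
                continue  # masked-out key
            score = 0.0
            for h in range(len(q[t])):
                dot = sum(a * b for a, b in zip(q[t][h], k[j]))
                # Apply both scales, then sum across heads.
                score += dot * q_s[t][h] * k_s[j]
            scored.append((score, j))
        # Highest score first; ties go to the lower key index (assumption).
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        row = [j for _, j in scored[:top_k]]
        # Pad with -1 where fewer than top_k keys are valid.
        row += [-1] * (top_k - len(row))
        out.append(row)
    return out


q = [[[1.0, 0.0]], [[0.0, 1.0]]]   # 2 tokens, 1 head, depth 2
q_s = [[1.0], [1.0]]
k = [[1.0, 0.0], [0.0, 2.0]]       # 2 cached keys
k_s = [1.0, 1.0]
result = indexer_reference(q, q_s, k, k_s, top_k=2)
```

In this toy input the first token can see only key 0 under the causal mask, so its row is padded with -1, matching the fill behaviour described for output_indices.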
Args:
- output_indices (TileTensor[DType.int32, address_space=output_indices.address_space, linear_idx_type=output_indices.linear_idx_type, element_size=output_indices.element_size]): Dense output tensor for top-k indices [total_seq_len, top_k]. Invalid positions (where there are fewer than top_k valid keys due to causal masking or shorter sequences) are filled with -1.
- q (TileTensor[dtype, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size]): Query tensor [total_seq_len, num_heads, head_dim] in FP8.
- q_s (TileTensor[DType.float32, address_space=q_s.address_space, linear_idx_type=q_s.linear_idx_type, element_size=q_s.element_size]): Query scales [total_seq_len, num_heads] in float32.
- input_row_offsets (TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size]): Ragged row offsets for queries [batch_size + 1].
- k_collection (KCollectionT): KV collection containing cached K values and K scales. K scales are accessed via k_cache.scales (quantization_granularity=head_size).
- layer_idx (UInt32): Layer index for retrieving cache.
- ctx (DeviceContext): Device context.
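To make the ragged layout concrete, the following sketch shows how an offsets array of length batch_size + 1 is usually read. It assumes the common convention that sequence b owns query rows offsets[b] up to but not including offsets[b + 1], which the page implies but does not spell out. The helper names are hypothetical.

```python
def sequence_ranges(input_row_offsets):
    """Return the (start, end) query-row range of each sequence in the batch."""
    return [
        (input_row_offsets[b], input_row_offsets[b + 1])
        for b in range(len(input_row_offsets) - 1)
    ]


def output_shape(input_row_offsets, top_k):
    """Shape of the dense output_indices tensor: [total_seq_len, top_k]."""
    total_seq_len = input_row_offsets[-1] - input_row_offsets[0]
    return (total_seq_len, top_k)


offsets = [0, 3, 5]  # two sequences, with 3 and 2 query tokens
ranges = sequence_ranges(offsets)
shape = output_shape(offsets, top_k=4)
```

Here the batch holds two sequences packed back to back, so q and output_indices both have five rows along their first dimension.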