Mojo function
mla_indexer_ragged_float8_paged
mla_indexer_ragged_float8_paged[dtype: DType, KCollectionT: KVCollectionT, num_heads: Int, depth: Int, top_k: Int, mask_str: StringSlice[StaticConstantOrigin]](output_indices: TileTensor[DType.int32, output_indices.LayoutType, output_indices.origin, address_space=output_indices.address_space, linear_idx_type=output_indices.linear_idx_type, element_size=output_indices.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], q_s: TileTensor[DType.float32, q_s.LayoutType, q_s.origin, address_space=q_s.address_space, linear_idx_type=q_s.linear_idx_type, element_size=q_s.element_size], input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], k_collection: KCollectionT, layer_idx: UInt32, ctx: DeviceContext)
Compute FP8 indexed attention scores using paged KV cache and return top-k indices.
This function:
- Computes FP8 matmul between q and cached k (with scales), aggregated across heads
- Applies the specified mask (causal, etc.)
- Computes top-k indices per token (scores are summed across all heads)
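The three steps above can be sketched as a plain Python reference for a single sequence. This is an illustration only, not the kernel's implementation: the name `topk_indices_reference`, the assumption that one cached K vector per token is shared by all heads, the way the scales are multiplied in, and the descending ordering of the returned indices are all assumptions. Any activation the real kernel may apply before summing across heads is omitted.

```python
def topk_indices_reference(q, q_s, k, k_s, top_k):
    """Reference sketch for one sequence (hypothetical helper, not part of the API).

    q:   [seq_len][num_heads][depth] dequantized query values
    q_s: [seq_len][num_heads] per-head query scales
    k:   [num_keys][depth] cached key values (assumed shared across heads)
    k_s: [num_keys] per-key scales (assumed one scale per cached token)
    """
    out = []
    for t, (q_t, qs_t) in enumerate(zip(q, q_s)):
        scored = []
        for j, (k_j, ks_j) in enumerate(zip(k, k_s)):
            if j > t:  # causal mask: token t only sees keys 0..t
                continue
            # Scaled dot product per head, summed across all heads.
            score = sum(
                qs_h * ks_j * sum(a * b for a, b in zip(q_h, k_j))
                for q_h, qs_h in zip(q_t, qs_t)
            )
            scored.append((score, j))
        # Highest score first; ties broken by lower key index (assumption).
        scored.sort(key=lambda pair: (-pair[0], pair[1]))
        idx = [j for _, j in scored[:top_k]]
        # Pad with -1 when fewer than top_k keys are visible.
        out.append(idx + [-1] * (top_k - len(idx)))
    return out
```

With `top_k=2`, the first token of a sequence sees only one key under the causal mask, so its output row holds one valid index followed by `-1`.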
Args:
- output_indices (TileTensor): Dense output tensor for top-k indices [total_seq_len, top_k]. Invalid positions (where there are fewer than top_k valid keys due to causal masking or shorter sequences) are filled with -1.
- q (TileTensor): Query tensor [total_seq_len, num_heads, head_dim] in FP8.
- q_s (TileTensor): Query scales [total_seq_len, num_heads] in float32.
- input_row_offsets (TileTensor): Ragged row offsets for queries [batch_size + 1].
- k_collection (KCollectionT): KV collection containing cached K values and K scales. K scales are accessed via k_cache.scales (quantization_granularity=head_size).
- layer_idx (UInt32): Layer index for retrieving cache.
- ctx (DeviceContext): Device context.
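As a rough illustration of the ragged layout, the sketch below assumes the usual convention that sequence `b` owns rows `input_row_offsets[b]` up to (but not including) `input_row_offsets[b + 1]` of the flattened `[total_seq_len, ...]` tensors. The helper name `split_ragged` is hypothetical and is not part of this API.

```python
def split_ragged(rows, row_offsets):
    """Split flattened rows into per-sequence lists (hypothetical helper).

    row_offsets has batch_size + 1 entries; the last entry equals total_seq_len.
    """
    batch_size = len(row_offsets) - 1
    return [rows[row_offsets[b]:row_offsets[b + 1]] for b in range(batch_size)]
```

For example, offsets `[0, 2, 5]` describe a batch of two sequences with lengths 2 and 3, stored back to back in five rows.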