
Mojo function

mla_indexer_ragged_float8_paged

mla_indexer_ragged_float8_paged[dtype: DType, KCollectionT: KVCollectionT, num_heads: Int, depth: Int, top_k: Int, mask_str: StringSlice[StaticConstantOrigin]](output_indices: TileTensor[DType.int32, output_indices.LayoutType, output_indices.origin, address_space=output_indices.address_space, linear_idx_type=output_indices.linear_idx_type, element_size=output_indices.element_size], q: TileTensor[dtype, q.LayoutType, q.origin, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], q_s: TileTensor[DType.float32, q_s.LayoutType, q_s.origin, address_space=q_s.address_space, linear_idx_type=q_s.linear_idx_type, element_size=q_s.element_size], input_row_offsets: TileTensor[DType.uint32, input_row_offsets.LayoutType, input_row_offsets.origin, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], k_collection: KCollectionT, layer_idx: UInt32, ctx: DeviceContext)

Compute FP8 indexed attention scores using paged KV cache and return top-k indices.

This function:

  1. Computes an FP8 matmul between q and the cached K values, applies the query and key scales, and sums the resulting scores across all heads
  2. Applies the mask selected by mask_str (for example, causal)
  3. Selects the top-k key indices for each query token from the head-summed scores
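The three steps above can be sketched as a plain-Python reference model for a single sequence. This is not the Mojo kernel or its API. It assumes the FP8 values have already been converted to floats, that K has a single head shared by all query heads with one scale per key row (one reading of quantization_granularity=head_size), that dequantization multiplies each raw dot product by both scales, and that a causal mask lets the query at absolute position p see keys 0..p. The start_pos argument (keys already cached before this chunk) is hypothetical.

```python
from typing import List


def indexer_topk_reference(
    q: List[List[List[float]]],  # [seq_len, num_heads, depth], FP8 values as floats
    q_s: List[List[float]],      # [seq_len, num_heads] query scales
    k: List[List[float]],        # [num_keys, depth]; single shared K head (assumption)
    k_s: List[float],            # [num_keys]; one scale per key row (assumption)
    top_k: int,
    start_pos: int = 0,          # hypothetical: keys cached before this chunk
) -> List[List[int]]:
    out: List[List[int]] = []
    for t, q_tok in enumerate(q):
        # Causal mask: token t sits at absolute position start_pos + t
        # and may only see keys at positions <= that.
        visible = min(start_pos + t + 1, len(k))
        scores = []
        for s in range(visible):
            total = 0.0
            # Step 1: scaled dot product per head, summed across heads.
            for h, q_head in enumerate(q_tok):
                dot = sum(a * b for a, b in zip(q_head, k[s]))
                total += q_s[t][h] * k_s[s] * dot
            scores.append((total, s))
        # Step 3: highest score first; ties broken toward the lower key index.
        scores.sort(key=lambda p: (-p[0], p[1]))
        row = [s for _, s in scores[:top_k]]
        # Fewer than top_k visible keys: pad with -1, as output_indices does.
        row += [-1] * (top_k - len(row))
        out.append(row)
    return out


# Usage: 3 tokens, 2 heads, depth 2; head-summed key scores are 1, 3, 4.
q = [[[1.0, 0.0], [0.0, 1.0]]] * 3
q_s = [[1.0, 1.0]] * 3
k = [[1.0, 0.0], [0.0, 3.0], [2.0, 2.0]]
k_s = [1.0, 1.0, 1.0]
print(indexer_topk_reference(q, q_s, k, k_s, top_k=2))  # [[0, -1], [1, 0], [2, 1]]
```

This sketch sorts each row by descending score; the ordering of indices within a row produced by the real kernel is not specified on this page.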

Args:

  • output_indices (TileTensor): Dense output tensor for top-k indices [total_seq_len, top_k]. Invalid positions (where there are fewer than top_k valid keys due to causal masking or shorter sequences) are filled with -1.
  • q (TileTensor): Query tensor [total_seq_len, num_heads, head_dim] in FP8.
  • q_s (TileTensor): Query scales [total_seq_len, num_heads] in float32.
  • input_row_offsets (TileTensor): Ragged row offsets for queries [batch_size + 1].
  • k_collection (KCollectionT): KV collection containing cached K values and K scales. K scales are accessed via k_cache.scales (quantization_granularity=head_size).
  • layer_idx (UInt32): Layer index used to select this layer's cache from k_collection.
  • ctx (DeviceContext): Device context used to launch the computation.
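The ragged layout described by input_row_offsets and the -1 padding of output_indices can be illustrated with the following plain-Python sketch. It assumes the offsets are cumulative and start at 0, so sequence b owns rows input_row_offsets[b] through input_row_offsets[b + 1] - 1 and total_seq_len equals the last offset. The helper names and the cache_lens argument (keys already cached for each sequence before its new tokens) are hypothetical, and the padding count assumes a causal mask.

```python
from typing import List, Tuple


def ragged_rows(input_row_offsets: List[int]) -> List[Tuple[int, int]]:
    """Map each flat query row to (batch_index, position_within_sequence).

    Assumes cumulative offsets starting at 0, as in a [batch_size + 1] layout.
    """
    rows: List[Tuple[int, int]] = []
    for b in range(len(input_row_offsets) - 1):
        start, end = input_row_offsets[b], input_row_offsets[b + 1]
        rows.extend((b, i) for i in range(end - start))
    return rows


def expected_padding(
    input_row_offsets: List[int], top_k: int, cache_lens: List[int]
) -> List[int]:
    """Number of trailing -1 entries per output row under a causal mask.

    cache_lens[b] is a hypothetical count of keys cached for sequence b
    before its new tokens; the query at local position i then sees
    cache_lens[b] + i + 1 keys.
    """
    padding = []
    for b, i in ragged_rows(input_row_offsets):
        visible = cache_lens[b] + i + 1
        padding.append(max(0, top_k - visible))
    return padding


# Two sequences of lengths 2 and 3, so total_seq_len = 5.
offsets = [0, 2, 5]
print(ragged_rows(offsets))  # [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
print(expected_padding(offsets, top_k=3, cache_lens=[0, 1]))  # [2, 1, 1, 0, 0]
```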
