Mojo function

mla_indexer_ragged_float8_paged

mla_indexer_ragged_float8_paged[dtype: DType, KCollectionT: KVCollectionT, num_heads: Int, depth: Int, top_k: Int, mask_str: StringSlice[StaticConstantOrigin]](output_indices: TileTensor[DType.int32, address_space=output_indices.address_space, linear_idx_type=output_indices.linear_idx_type, element_size=output_indices.element_size], q: TileTensor[dtype, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], q_s: TileTensor[DType.float32, address_space=q_s.address_space, linear_idx_type=q_s.linear_idx_type, element_size=q_s.element_size], input_row_offsets: TileTensor[DType.uint32, address_space=input_row_offsets.address_space, linear_idx_type=input_row_offsets.linear_idx_type, element_size=input_row_offsets.element_size], k_collection: KCollectionT, layer_idx: UInt32, ctx: DeviceContext)

Compute FP8 indexed attention scores using a paged KV cache and return the top-k indices.

This function:

  1. Computes an FP8 matmul between q and the cached k (applying dequantization scales)
  2. Applies the specified mask (e.g. causal)
  3. Sums the masked scores across all heads and computes the top-k indices per token
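The scoring flow above can be sketched in plain Python/NumPy. This is a hedged illustration, not the actual kernel: the function name, the per-head query scales, and the per-position key scales are assumptions standing in for the FP8 quantization metadata, and the real implementation operates on ragged batches and a paged cache rather than dense arrays.

```python
import numpy as np

def indexer_topk_scores(q, q_scale, k, k_scale, top_k):
    """Illustrative sketch of the indexer score computation (not the real kernel).

    q:       [num_heads, depth] query for one token (int8 stands in for FP8)
    q_scale: [num_heads] per-head query dequantization scales (assumed layout)
    k:       [seq_len, depth] cached keys (quantized)
    k_scale: [seq_len] per-position key dequantization scales (assumed layout)
    Returns the indices of the top_k highest-scoring cache positions.
    """
    # 1. Matmul between q and cached k, then apply dequantization scales.
    scores = np.einsum("hd,sd->hs", q.astype(np.float32), k.astype(np.float32))
    scores *= q_scale[:, None] * k_scale[None, :]

    # 2. A causal mask would exclude future positions; for a single decode
    #    token every cached position is visible, so no mask is applied here.

    # 3. Sum scores across heads, then take the top-k indices.
    summed = scores.sum(axis=0)
    kth = min(top_k, len(summed) - 1)
    idx = np.argpartition(-summed, kth)[:top_k]
    # Order the selected indices by descending score.
    return idx[np.argsort(-summed[idx])]
```

The head-summed score means the top-k selection is shared across heads: one index list per token, not one per head, which matches step 3 above.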

Args: