
Mojo struct

MLAIndexerRaggedFloat8Paged

struct MLAIndexerRaggedFloat8Paged

Implemented traits

AnyType, ImplicitlyDestructible

Methods

execute

static def execute[*, num_heads: Int, depth: Int, k: Int, quantization_granularity: Int, mask_str: StringSlice[StaticConstantOrigin]](
    output_indices: ManagedTensorSlice[Output, static_spec=output_indices.static_spec],
    q: ManagedTensorSlice[Input, static_spec=q.static_spec],
    qs: ManagedTensorSlice[Input, static_spec=qs.static_spec],
    input_row_offsets: ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec],
    k_blocks: ManagedTensorSlice[MutableInput, static_spec=k_blocks.static_spec],
    k_cache_lengths: ManagedTensorSlice[Input, static_spec=k_cache_lengths.static_spec],
    k_lookup_table: ManagedTensorSlice[Input, static_spec=k_lookup_table.static_spec],
    k_max_lengths: ManagedTensorSlice[Input, static_spec=k_max_lengths.static_spec],
    k_scales: ManagedTensorSlice[MutableInput, static_spec=k_scales.static_spec],
    layer_idx: UInt32,
    ctx: DeviceContext)

Compute FP8 attention scores and return top-k key indices per token.

This kernel is designed for Multi-head Latent Attention (MLA) architectures. It multiplies FP8 queries by cached FP8 keys, applies the associated scales and the selected mask, and returns the indices of the k highest-scoring keys for each token. Scores are summed across all attention heads before the top-k selection.
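The scoring and selection described above can be sketched in plain Python for a single query token. This is an illustrative reference under stated assumptions, not the kernel's implementation: the function names are hypothetical, FP8 values are stood in for by ordinary floats, one key vector per position is assumed to be shared by all heads, one scale per key is assumed (the real granularity is set by quantization_granularity), and the causal convention and tie-breaking order are guesses.

```python
def visible_key_count(cache_len, seq_len, token_idx, causal):
    """Number of keys a query token may score (assumed convention).

    Assumes keys are ordered as cached tokens followed by the new tokens,
    and that a causal mask lets new token i see keys 0..cache_len + i.
    """
    if causal:
        return cache_len + token_idx + 1
    return cache_len + seq_len


def topk_key_indices(q, q_scales, keys, key_scales, k, num_visible):
    """Return k key indices for one query token, padded with -1.

    q:          [num_heads][depth] quantized query values (plain floats here)
    q_scales:   [num_heads] dequantization scale per head
    keys:       [num_keys][depth] quantized key values, shared by all heads
    key_scales: [num_keys] one scale per key (assumed granularity)
    """
    scored = []
    for j in range(min(num_visible, len(keys))):
        total = 0.0
        for head, q_scale in zip(q, q_scales):
            dot = sum(a * b for a, b in zip(head, keys[j]))
            total += dot * q_scale * key_scales[j]  # apply scales, sum heads
        scored.append((total, j))
    # Highest score first; ties go to the lower index (assumed tie-break).
    scored.sort(key=lambda item: (-item[0], item[1]))
    picked = [j for _, j in scored[:k]]
    # Fewer than k visible keys: fill the remaining slots with -1.
    return picked + [-1] * (k - len(picked))
```

In this sketch, a token that can see only one key with k = 3 yields one real index followed by two -1 entries, which mirrors the padding behaviour documented for output_indices.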

Parameters:

  • num_heads (Int): Number of query attention heads (must be 128).
  • depth (Int): Head dimension (must be 128).
  • k (Int): Number of top indices to return per token.
  • quantization_granularity (Int): Quantization granularity for the K cache.
  • mask_str (StringSlice[StaticConstantOrigin]): Mask type - either MaskName.NULL (no mask) or MaskName.CAUSAL.

Args:

  • output_indices (ManagedTensorSlice[Output, static_spec=output_indices.static_spec]): Output tensor [total_seq_len, k] containing the top-k key indices per token. Invalid positions (where there are fewer than k valid keys) are filled with -1.
  • q (ManagedTensorSlice[Input, static_spec=q.static_spec]): Query tensor [total_seq_len, num_heads, depth] in FP8.
  • qs (ManagedTensorSlice[Input, static_spec=qs.static_spec]): Query scales [total_seq_len, num_heads] in float32.
  • input_row_offsets (ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec]): Ragged row offsets [batch_size + 1] for queries.
  • k_blocks (ManagedTensorSlice[MutableInput, static_spec=k_blocks.static_spec]): Paged K cache blocks [num_blocks, 1, num_layers, page_size, num_heads, head_size] in FP8.
  • k_cache_lengths (ManagedTensorSlice[Input, static_spec=k_cache_lengths.static_spec]): Cache lengths [batch_size] - number of cached tokens per sequence.
  • k_lookup_table (ManagedTensorSlice[Input, static_spec=k_lookup_table.static_spec]): Page lookup table [batch_size, pages_per_seq] mapping sequence pages to block indices.
  • k_max_lengths (ManagedTensorSlice[Input, static_spec=k_max_lengths.static_spec]): Max lengths tensor [1, 2] containing [max_seq_len, max_cache_len].
  • k_scales (ManagedTensorSlice[MutableInput, static_spec=k_scales.static_spec]): Scale values for the K cache, stored in blocks that match the k_blocks shape.
  • layer_idx (UInt32): Layer index for retrieving the correct cache layer.
  • ctx (DeviceContext): Device context for GPU execution.
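The ragged and paged layouts in the argument list can be illustrated with a short Python sketch. It shows how a flat token index might be mapped to its sequence through input_row_offsets, and how a key position might be resolved to a cache block through k_lookup_table. The helper names are hypothetical, and the page arithmetic (page = position // page_size) is an assumed, conventional paging scheme rather than something this page confirms.

```python
import bisect


def sequence_of_token(input_row_offsets, token_idx):
    """Find which sequence a flat (ragged) token index belongs to.

    input_row_offsets has batch_size + 1 entries; sequence b owns the
    tokens in [input_row_offsets[b], input_row_offsets[b + 1]).
    """
    return bisect.bisect_right(input_row_offsets, token_idx) - 1


def locate_key(k_lookup_table, page_size, batch_idx, key_pos):
    """Map a key position in one sequence to (block index, offset in page).

    Assumes the usual paged-cache convention: position p lives in page
    p // page_size at offset p % page_size.
    """
    page, offset = divmod(key_pos, page_size)
    block = k_lookup_table[batch_idx][page]
    return block, offset


# Example: two sequences with 3 and 2 query tokens, pages of 4 keys each.
offsets = [0, 3, 5]
lookup = [[7, 2], [5, 9]]
batch_idx = sequence_of_token(offsets, 3)       # token 3 is in sequence 1
block, offset = locate_key(lookup, 4, batch_idx, 5)
```

Given the documented k_blocks shape, the keys for that position would then sit at k_blocks[block][0][layer_idx][offset], a slice of shape [num_heads, head_size].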