Mojo struct
MLAIndexerRaggedFloat8Paged
struct MLAIndexerRaggedFloat8Paged
Implemented traits
AnyType,
ImplicitlyDestructible
Methods
execute
static def execute[*, num_heads: Int, depth: Int, k: Int, quantization_granularity: Int, mask_str: StringSlice[StaticConstantOrigin]](output_indices: ManagedTensorSlice[Output, static_spec=output_indices.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], qs: ManagedTensorSlice[Input, static_spec=qs.static_spec], input_row_offsets: ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec], k_blocks: ManagedTensorSlice[MutableInput, static_spec=k_blocks.static_spec], k_cache_lengths: ManagedTensorSlice[Input, static_spec=k_cache_lengths.static_spec], k_lookup_table: ManagedTensorSlice[Input, static_spec=k_lookup_table.static_spec], k_max_lengths: ManagedTensorSlice[Input, static_spec=k_max_lengths.static_spec], k_scales: ManagedTensorSlice[MutableInput, static_spec=k_scales.static_spec], layer_idx: UInt32, ctx: DeviceContext)
Compute FP8 attention scores and return top-k key indices per token.
This kernel is designed for Multi-head Latent Attention (MLA) architectures. It performs an FP8 matrix multiplication between the queries and the cached keys, applying the query and key scales, then applies the mask and returns the indices of the k highest-scoring keys for each token. Scores are summed across all attention heads before the top-k selection.
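The scoring and selection described above can be illustrated with a small CPU reference in Python. This is a sketch of the semantics under stated assumptions, not the Mojo kernel: it treats every key position as one depth-dimensional vector shared by all query heads with a single scale per position, models FP8 values as ordinary floats, and uses the hypothetical helper names `head_summed_scores` and `top_k_indices`.

```python
# CPU reference sketch of the indexer semantics; NOT the Mojo kernel.
# Assumptions (not confirmed by the API page):
#   * each key position is one depth-dim vector shared by every query head,
#   * each key position has a single dequantization scale,
#   * FP8 values are modeled as ordinary Python floats.


def head_summed_scores(q, qs, keys, key_scales):
    """Score every key for one query token.

    q: [num_heads][depth] query values, qs: [num_heads] query scales,
    keys: [num_keys][depth] key values, key_scales: [num_keys] key scales.
    Returns one score per key, summed across heads.
    """
    scores = []
    for key, k_scale in zip(keys, key_scales):
        total = 0.0
        for head_vec, q_scale in zip(q, qs):
            # Scaled dot product for one head; heads are accumulated.
            dot = sum(a * b for a, b in zip(head_vec, key))
            total += dot * q_scale * k_scale
        scores.append(total)
    return scores


def top_k_indices(scores, k, num_valid):
    """Return the k best key indices among the first num_valid keys.

    Keys at positions >= num_valid are treated as masked out. Slots left
    over when fewer than k keys are valid are padded with -1.
    """
    ranked = sorted(range(num_valid), key=lambda j: scores[j], reverse=True)
    chosen = ranked[:k]
    return chosen + [-1] * (k - len(chosen))


q = [[1.0, 0.0], [0.0, 1.0]]          # 2 heads, depth 2
qs = [1.0, 0.5]
keys = [[1.0, 1.0], [2.0, 0.0], [0.0, 4.0]]
key_scales = [1.0, 1.0, 0.5]
scores = head_summed_scores(q, qs, keys, key_scales)
print(scores)                          # [1.5, 2.0, 1.0]
print(top_k_indices(scores, 2, 3))     # [1, 0]
print(top_k_indices(scores, 4, 2))     # [1, 0, -1, -1]
```

The sketch returns indices in descending score order. This page does not say whether the kernel orders its output, so if you use a reference like this to check kernel results, comparing the index sets is the safer test.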
Parameters:
- num_heads (Int): Number of query attention heads (must be 128).
- depth (Int): Head dimension (must be 128).
- k (Int): Number of top indices to return per token.
- quantization_granularity (Int): Quantization granularity for the K cache.
- mask_str (StringSlice[StaticConstantOrigin]): Mask type, either MaskName.NULL (no mask) or MaskName.CAUSAL.
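The two mask options differ in how many keys each query token may score. A minimal sketch, assuming that a sequence's keys are its cache_len cached tokens followed by its seq_len new tokens, and using the hypothetical strings "NULL" and "CAUSAL" as stand-ins for the MaskName values:

```python
# Hypothetical sketch of how many keys one query token may score.
# Assumptions: a sequence's keys are its cache_len cached tokens followed by
# its seq_len new tokens, and the strings "NULL"/"CAUSAL" stand in for the
# MaskName values, whose real representation is not shown on this page.


def num_visible_keys(mask, cache_len, seq_len, i):
    """Number of leading key positions visible to the query at local index i."""
    if not 0 <= i < seq_len:
        raise IndexError(f"query index {i} outside sequence of length {seq_len}")
    if mask == "NULL":
        return cache_len + seq_len      # no mask: every key is visible
    if mask == "CAUSAL":
        return cache_len + i + 1        # cached keys plus new keys up to i
    raise ValueError(f"unsupported mask: {mask!r}")


print([num_visible_keys("CAUSAL", 4, 3, i) for i in range(3)])  # [5, 6, 7]
print(num_visible_keys("NULL", 4, 3, 0))                        # 7
```

Under this assumption, a sequence with 4 cached tokens and 3 new tokens lets its three query tokens see 5, 6, and 7 keys under CAUSAL, while NULL lets every token see all 7. When fewer than k keys are visible, the remaining output slots would be the -1 entries described for output_indices.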
Args:
- output_indices (ManagedTensorSlice[Output, static_spec=output_indices.static_spec]): Output tensor [total_seq_len, k] containing the top-k key indices for each token. Slots that cannot be filled because fewer than k valid keys exist are set to -1.
- q (ManagedTensorSlice[Input, static_spec=q.static_spec]): Query tensor [total_seq_len, num_heads, depth] in FP8.
- qs (ManagedTensorSlice[Input, static_spec=qs.static_spec]): Query scales [total_seq_len, num_heads] in float32.
- input_row_offsets (ManagedTensorSlice[Input, static_spec=input_row_offsets.static_spec]): Ragged row offsets [batch_size + 1] for the queries.
- k_blocks (ManagedTensorSlice[MutableInput, static_spec=k_blocks.static_spec]): Paged K cache blocks [num_blocks, 1, num_layers, page_size, num_heads, head_size] in FP8.
- k_cache_lengths (ManagedTensorSlice[Input, static_spec=k_cache_lengths.static_spec]): Cache lengths [batch_size]: the number of cached tokens in each sequence.
- k_lookup_table (ManagedTensorSlice[Input, static_spec=k_lookup_table.static_spec]): Page lookup table [batch_size, pages_per_seq] mapping each sequence's pages to block indices.
- k_max_lengths (ManagedTensorSlice[Input, static_spec=k_max_lengths.static_spec]): Max lengths tensor [1, 2] containing [max_seq_len, max_cache_len].
- k_scales (ManagedTensorSlice[MutableInput, static_spec=k_scales.static_spec]): K scale blocks with the same layout as k_blocks, holding the scale values for the cached keys.
- layer_idx (UInt32): Index of the layer whose cache entries are read.
- ctx (DeviceContext): Device context for GPU execution.
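The ragged and paged arguments can be read together as follows. This Python sketch assumes that input_row_offsets is an exclusive prefix sum (sequence b owns query rows input_row_offsets[b] up to, but not including, input_row_offsets[b + 1]) and that key position p of sequence b sits in block k_lookup_table[b][p // page_size] at slot p % page_size. The helper names `sequence_of_row` and `key_location` are hypothetical.

```python
import bisect

# Hypothetical helpers for reading the ragged and paged inputs. Assumptions:
#   * input_row_offsets is an exclusive prefix sum over sequence lengths,
#   * key position p of sequence b is stored in block
#     k_lookup_table[b][p // page_size] at slot p % page_size.


def sequence_of_row(input_row_offsets, row):
    """Map a flattened query row to (batch index, position within sequence)."""
    if not 0 <= row < input_row_offsets[-1]:
        raise IndexError(f"row {row} out of range")
    # bisect_right skips empty sequences whose offsets repeat.
    b = bisect.bisect_right(input_row_offsets, row) - 1
    return b, row - input_row_offsets[b]


def key_location(k_lookup_table, page_size, b, p):
    """Map key position p of sequence b to (block index, slot within page)."""
    return k_lookup_table[b][p // page_size], p % page_size


offsets = [0, 3, 3, 5]          # three sequences of lengths 3, 0, 2
print(sequence_of_row(offsets, 4))                   # (2, 1)
print(key_location([[7, 2], [5, 9]], 4, 1, 6))       # (9, 2)
```

Given the k_blocks shape listed above, the key vector for one head would then plausibly be found at k_blocks[block, 0, layer_idx, slot, head, :], though this page does not spell out that indexing.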