Mojo function

mla_sm100_prefill_sparse

mla_sm100_prefill_sparse[output_type: DType, q_type: DType, cache_t: KVCacheT, //, num_q_heads: Int, qk_depth: Int, v_depth: Int, indices_stride: Int](output: TileTensor[output_type, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[q_type, linear_idx_type=q.linear_idx_type, element_size=q.element_size], kv_cache: cache_t, indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], topk_lengths: TileTensor[DType.uint32, linear_idx_type=topk_lengths.linear_idx_type, element_size=topk_lengths.element_size], attn_sink_ptr: UnsafePointer[Float32, ImmutAnyOrigin], scale: Float32, ctx: DeviceContext)

Sparse MLA prefill (DSv3.2 absorbed shape, BF16, SM100).

Thin wrapper around mla_prefill_sparse that builds the MLASparseConfig from the passed dimensions so callers don't have to reach into the kernel's config type. The kernel itself hardcodes the DSv3.2 absorbed/latent shape (qk_depth=576, v_depth=512, num_q_heads=128, num_kv_heads=1) and asserts on those values.

Parameters:

  • output_type (DType): Output element type (must be the same width as q_type; the kernel asserts this).
  • q_type (DType): Query element type (BF16 in the supported DSv3.2 shape).
  • cache_t (KVCacheT): KV cache type (typically a paged MLA cache obtained from kv_collection.get_key_cache(layer_idx)).
  • num_q_heads (Int): Number of query heads (must be 128 for the DSv3.2 absorbed shape).
  • qk_depth (Int): Per-head Q/K depth (must be 576 = kv_lora_rank(512) + qk_rope_head_dim(64)).
  • v_depth (Int): Per-head V depth (must be 512 = kv_lora_rank).
  • indices_stride (Int): Per-query stride of the indices buffer (equal to the indexer's index_topk). The same value is passed to the kernel as the runtime indices_stride.
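
The shape constraints listed above can be checked before calling the wrapper. The following is a minimal sketch only: is_dsv32_absorbed_shape is a hypothetical helper that is not part of the kernel library, and the constants are copied from the parameter descriptions on this page.

```mojo
# Hypothetical helper (not part of the library): returns True when the
# dimensions match the DSv3.2 absorbed shape described on this page.
def is_dsv32_absorbed_shape(num_q_heads: Int, qk_depth: Int, v_depth: Int) -> Bool:
    var kv_lora_rank = 512      # latent (absorbed) KV rank
    var qk_rope_head_dim = 64   # rotary part of the Q/K head
    return (
        num_q_heads == 128
        and qk_depth == kv_lora_rank + qk_rope_head_dim  # 512 + 64 = 576
        and v_depth == kv_lora_rank
    )


def main():
    # The supported configuration: 128 query heads, qk_depth 576, v_depth 512.
    print(is_dsv32_absorbed_shape(128, 576, 512))
    # A non-absorbed shape that this sketch rejects.
    print(is_dsv32_absorbed_shape(128, 192, 128))
```

A check like this only mirrors the documented values; the kernel's own assertions remain the authoritative guard.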

Args: