Mojo function
mla_sm100_prefill_sparse
mla_sm100_prefill_sparse[output_type: DType, q_type: DType, cache_t: KVCacheT, //, num_q_heads: Int, qk_depth: Int, v_depth: Int, indices_stride: Int](output: TileTensor[output_type, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[q_type, linear_idx_type=q.linear_idx_type, element_size=q.element_size], kv_cache: cache_t, indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], topk_lengths: TileTensor[DType.uint32, linear_idx_type=topk_lengths.linear_idx_type, element_size=topk_lengths.element_size], attn_sink_ptr: UnsafePointer[Float32, ImmutAnyOrigin], scale: Float32, ctx: DeviceContext)
Sparse MLA prefill (DSv3.2 absorbed shape, BF16, SM100).
Thin wrapper around mla_prefill_sparse that builds the
MLASparseConfig from the passed dimensions so callers don't have to
reach into the kernel's config type. The kernel itself hardcodes the
DSv3.2 absorbed/latent shape (qk_depth=576, v_depth=512,
num_q_heads=128, num_kv_heads=1) and asserts on those values.
Parameters:
- output_type (DType): Output element type (must be the same width as q_type; the kernel asserts this).
- q_type (DType): Query element type (BF16 in the supported DSv3.2 shape).
- cache_t (KVCacheT): KV cache type (typically a paged MLA cache obtained from kv_collection.get_key_cache(layer_idx)).
- num_q_heads (Int): Number of query heads (must be 128 for the DSv3.2 absorbed shape).
- qk_depth (Int): Per-head Q/K depth (must be 576 = kv_lora_rank (512) + qk_rope_head_dim (64)).
- v_depth (Int): Per-head V depth (must be 512 = kv_lora_rank).
- indices_stride (Int): Per-query indices buffer stride (= the indexer's index_topk). Also used as the runtime indices_stride to the kernel.
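The shape constraints in the parameter list above can be checked on the host before a launch is attempted. The following is a minimal sketch in plain Python rather than Mojo; the function name check_dsv32_shape and the constant names are hypothetical and not part of the Mojo API, and the kernel's own assertions remain the authoritative check.

```python
# Dimensions of the DSv3.2 absorbed/latent shape, taken from the parameter list.
KV_LORA_RANK = 512
QK_ROPE_HEAD_DIM = 64
NUM_Q_HEADS = 128


def check_dsv32_shape(num_q_heads: int, qk_depth: int, v_depth: int) -> None:
    """Hypothetical host-side pre-check; raises ValueError on an unsupported shape."""
    expected_qk_depth = KV_LORA_RANK + QK_ROPE_HEAD_DIM  # 512 + 64 = 576
    if num_q_heads != NUM_Q_HEADS:
        raise ValueError(f"num_q_heads must be {NUM_Q_HEADS}, got {num_q_heads}")
    if qk_depth != expected_qk_depth:
        raise ValueError(f"qk_depth must be {expected_qk_depth}, got {qk_depth}")
    if v_depth != KV_LORA_RANK:
        raise ValueError(f"v_depth must be {KV_LORA_RANK}, got {v_depth}")


# The only supported combination passes silently.
check_dsv32_shape(num_q_heads=128, qk_depth=576, v_depth=512)
```

Failing early with a descriptive message is only a convenience here; any other combination would still be rejected by the kernel.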
Args:
- output (TileTensor[output_type, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tile tensor with shape [total_q_tokens, num_q_heads, v_depth].
- q (TileTensor[q_type, linear_idx_type=q.linear_idx_type, element_size=q.element_size]): Query tile tensor with shape [total_q_tokens, num_q_heads, qk_depth].
- kv_cache (cache_t): Paged MLA KV cache for the current layer.
- indices (TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size]): Per-query gather4 indices, encoded as Int32(physical_block_id * page_size + token_offset_within_page) (reinterpreted via the uint32 tile-tensor view; -1-bit-pattern sentinels are masked out by the kernel's k-valid producer).
- topk_lengths (TileTensor[DType.uint32, linear_idx_type=topk_lengths.linear_idx_type, element_size=topk_lengths.element_size]): Per-query effective top-k count ([total_q_tokens]).
- attn_sink_ptr (UnsafePointer[Float32, ImmutAnyOrigin]): Optional attention sink (one Float32 per query head). Pass a null pointer to skip the sink term in the softmax epilogue.
- scale (Float32): Softmax scale (1 / sqrt(qk_nope_head_dim + qk_rope_head_dim) * mscale^2; for DSv3.2 with mscale=1, 1 / sqrt(192)).
- ctx (DeviceContext): GPU device context.
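The index encoding and the softmax scale described in the argument list above reduce to a few lines of arithmetic. The sketch below uses plain Python rather than Mojo and only illustrates the formulas; the helper names are hypothetical, the page_size of 128 is an arbitrary example value, and qk_nope_head_dim = 128 is inferred from the stated 1 / sqrt(192) with qk_rope_head_dim = 64.

```python
import math

# Int32(-1) reinterpreted through the uint32 view; such entries are masked out.
SENTINEL_U32 = 0xFFFFFFFF


def encode_index(physical_block_id: int, token_offset_within_page: int, page_size: int) -> int:
    """Flat token index: physical_block_id * page_size + token_offset_within_page."""
    if not 0 <= token_offset_within_page < page_size:
        raise ValueError("token offset must lie inside the page")
    return physical_block_id * page_size + token_offset_within_page


def as_uint32(value: int) -> int:
    """Bit pattern of a 32-bit signed integer when viewed as uint32."""
    return value & 0xFFFFFFFF


def softmax_scale(qk_nope_head_dim: int, qk_rope_head_dim: int, mscale: float = 1.0) -> float:
    """1 / sqrt(qk_nope_head_dim + qk_rope_head_dim) * mscale^2."""
    return mscale * mscale / math.sqrt(qk_nope_head_dim + qk_rope_head_dim)


# Example values only: block 3, offset 5, page_size 128 (assumed).
flat_index = encode_index(physical_block_id=3, token_offset_within_page=5, page_size=128)
padding_entry = as_uint32(-1)  # equals SENTINEL_U32
scale = softmax_scale(qk_nope_head_dim=128, qk_rope_head_dim=64)  # 1 / sqrt(192)
```

Unused slots in a query's indices row can be filled with the sentinel value, while topk_lengths carries the effective top-k count for that query.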