
Mojo struct

MlaPrefillV2Core

struct MlaPrefillV2Core[config: MlaConfigV2]

8-warp MLA forward kernel parameterized by MlaConfigV2.

Sibling of MhaPrefillV2. Each block runs config.num_warps wave64 warps that share K_nope / K_rope / V SMEM via cooperative DMA. Warp w owns Q rows [w * q_block_size, (w + 1) * q_block_size) of the block's stripe and carries its own register-resident attention state.
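The row ownership rule above can be sketched on the host in Python. This models only the index arithmetic and is not kernel code. The values num_warps = 8 and q_block_size = 32 are illustrative assumptions, and warp_q_rows is a hypothetical helper name.

```python
WAVE_SIZE = 64  # wave64: 64 lanes per warp, as stated above


def warp_q_rows(warp_id: int, q_block_size: int) -> range:
    """Q rows of the block's stripe owned by warp `warp_id` (hypothetical helper)."""
    start = warp_id * q_block_size
    return range(start, start + q_block_size)


# Assumed example values, not read from a real MlaConfigV2.
num_warps = 8
q_block_size = 32

for w in range(num_warps):
    rows = warp_q_rows(w, q_block_size)
    print(f"warp {w}: rows [{rows.start}, {rows.stop})")

# One block launches num_warps warps of 64 lanes each.
threads_per_block = num_warps * WAVE_SIZE
print("threads per block:", threads_per_block)
```

Under these assumed values the eight warps tile a 256-row stripe without overlap, and the block runs 512 threads.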

Parameters

  • config (MlaConfigV2): Shape configuration for the kernel.

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

CACHE_DEPTH​

comptime CACHE_DEPTH = config.cache_depth

Latent K cache row width. For DeepSeek-V3 MLA: 576. The gap between D_NOPE (128) and ROPE_CACHE_OFFSET (512) is reserved / unused but counted in the cache stride. Matches the existing BF16 MLA path at mla_prefill.mojo:54.

D_NOPE​

comptime D_NOPE = config.depth

Alias for DEPTH exposing the nope-segment semantic name. D_NOPE == DEPTH == d_pv, because DeepSeek-V3 MLA does not apply RoPE to V.

D_QK​

comptime D_QK = config.d_qk

Q / K depth (d_nope + d_rope). For DeepSeek-V3 MLA: 192.

D_ROPE​

comptime D_ROPE = config.d_rope

RoPE-applied segment depth on Q and K. For DeepSeek-V3 MLA: 64.

DEPTH​

comptime DEPTH = config.depth

V / O head depth (d_pv). For DeepSeek-V3 MLA: 128. Identical semantics to MhaPrefillV2.DEPTH: MhaMmaOp is specialized on this via config.mha(), so the V and PV machinery is shared verbatim.

K_NOPE_LAYOUT​

comptime K_NOPE_LAYOUT = MlaMmaOp[config.dtype, config.mha()].K_LAYOUT

K_nope register tile, identical to the MHA path's K register tile at d=d_nope. Re-exposed from _MmaOp.K_LAYOUT.

K_ROPE_LAYOUT​

comptime K_ROPE_LAYOUT = row_major[(config // (Int(16) if fp8 and config.mha().fp8_mma_k_128 else Int(32))), (config // (Int(128) if fp8 and config.mha().fp8_mma_k_128 else Int(64) if fp8 else Int(16))), (Int((Int(128) if fp8 and config.mha().fp8_mma_k_128 else Int(64) if fp8 else Int(16)) * (Int(16) if fp8 and config.mha().fp8_mma_k_128 else Int(32))) // Int(64))]()

Here fp8 abbreviates the generated dtype test, which holds when the dtype code of config.dtype is one of 73 through 78 (the FP8 case in the descriptions on this page).

K_rope register tile at d_rope. Loaded by MhaMmaOp.load_K verbatim. The column loop has width 1 for FP8 and 4 for BF16. The SMEM source layout is smem_layout_k_rope.

k_swizzle​

comptime k_swizzle = Optional(Swizzle(Int(1), Int(0), Int(4)))

k_swizzle2​

comptime k_swizzle2 = Optional(Swizzle(Int(1), Int(1), Int(4)))

KTileLoader​

comptime KTileLoader = SubTileLoaderLDS[config.dtype, MlaPrefillV2Core[config].k_swizzle, MlaPrefillV2Core[config].k_swizzle2]

KV_BLOCK​

comptime KV_BLOCK = config.kv_block

NUM_HEADS​

comptime NUM_HEADS = config.num_heads

NUM_KV_HEADS​

comptime NUM_KV_HEADS = config.num_kv_heads

NUM_THREADS​

comptime NUM_THREADS = (config * Int(64))

NUM_WARPS​

comptime NUM_WARPS = config.num_warps

prescale_q​

comptime prescale_q = not config.dtype.is_float8().__bool__()

Q prescale at load time. True for BF16 (a single FP32 multiply per Q element at load). False for FP8, which avoids FP32-to-FP8 precision loss: the scale is applied to att_block after the QK product instead.
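The two scaling orders can be compared with a small Python sketch. It shows that scaling Q before the dot product and scaling the score afterwards agree in exact arithmetic, so the choice is about where rounding happens. Python floats are double precision, so this does not model FP8 rounding. The scale 1/sqrt(d_qk) and all function names are assumptions made for illustration.

```python
import math


def dot(a, b):
    """Plain dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


def score_prescaled(q, k, scale):
    """Scale Q once at load time, then take the dot product (prescale_q = True)."""
    q_scaled = [scale * x for x in q]
    return dot(q_scaled, k)


def score_postscaled(q, k, scale):
    """Take the dot product on unscaled Q, then scale the score (prescale_q = False)."""
    return scale * dot(q, k)


q = [0.5, -1.25, 2.0, 0.75]
k = [1.5, 0.25, -0.5, 2.0]
scale = 1.0 / math.sqrt(192)  # assumed softmax scale, 1/sqrt(d_qk)

a = score_prescaled(q, k, scale)
b = score_postscaled(q, k, scale)
print(a, b, math.isclose(a, b, rel_tol=1e-9))
```

With a low-precision Q, the prescaled variant would round scale * q back into the storage type, which is the loss the FP8 path avoids by scaling the score instead.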

Q_BLOCK_SIZE​

comptime Q_BLOCK_SIZE = config.q_block_size

Q_LAYOUT_MLA​

comptime Q_LAYOUT_MLA = row_major[(config // (Int(16) if fp8 and config.mha().fp8_mma_k_128 else Int(32))), (config // (Int(128) if fp8 and config.mha().fp8_mma_k_128 else Int(64) if fp8 else Int(16))), (Int((Int(128) if fp8 and config.mha().fp8_mma_k_128 else Int(64) if fp8 else Int(16)) * (Int(16) if fp8 and config.mha().fp8_mma_k_128 else Int(32))) // Int(64))]()

Here fp8 abbreviates the generated dtype test, which holds when the dtype code of config.dtype is one of 73 through 78 (the FP8 case in the descriptions on this page).

Q register tile at d_qk. For FP8 KV_BLOCK=64 d_qk=192: shape [1, 3, 32] (1 row base × 3 col base × 32 FP8 elts/lane). For BF16: [1, 12, 8].
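The quoted shapes can be reproduced with a short Python sketch of the tile arithmetic. The base tile sizes (32 rows for both dtypes, K width 64 for FP8 and 16 for BF16), the 32 Q rows per warp, and the function name are assumptions inferred from the shapes above, not values read from MlaConfigV2.

```python
def register_tile_shape(rows, depth, mma_m, mma_k, lanes=64):
    """Return [row base tiles, col base tiles, elements per lane] for one warp."""
    return [rows // mma_m, depth // mma_k, (mma_m * mma_k) // lanes]


# Assumed per-dtype base tile sizes (inferred, see the note above).
FP8 = {"mma_m": 32, "mma_k": 64}
BF16 = {"mma_m": 32, "mma_k": 16}

q_rows = 32  # assumed Q rows per warp
d_qk = 192   # DeepSeek-V3 MLA Q / K depth

fp8_shape = register_tile_shape(q_rows, d_qk, **FP8)
bf16_shape = register_tile_shape(q_rows, d_qk, **BF16)
print(fp8_shape)   # [1, 3, 32]
print(bf16_shape)  # [1, 12, 8]
```

The same arithmetic at d_rope = 64 gives 1 column base tile for FP8 and 4 for BF16.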

ROPE_CACHE_OFFSET​

comptime ROPE_CACHE_OFFSET = config.rope_cache_offset

Column offset of k_rope within the latent cache row. For DeepSeek-V3 MLA: 512 (k_nope at [:, :128], gap, k_rope at [:, 512:576]).
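As a rough illustration, the row layout can be modeled in Python with a plain list standing in for one latent cache row. The constants are the DeepSeek-V3 values quoted on this page. The variable names are illustrative.

```python
D_NOPE = 128
D_ROPE = 64
ROPE_CACHE_OFFSET = 512
CACHE_DEPTH = 576

# Stand-in for one latent cache row: element i holds its own column index.
row = list(range(CACHE_DEPTH))

k_nope = row[:D_NOPE]                                       # columns 0..127
k_rope = row[ROPE_CACHE_OFFSET:ROPE_CACHE_OFFSET + D_ROPE]  # columns 512..575
gap = ROPE_CACHE_OFFSET - D_NOPE                            # columns skipped between them

print(len(k_nope), len(k_rope), gap)  # 128 64 384
```

The rope segment ends exactly at CACHE_DEPTH, so the row stride covers nope, gap, and rope together.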

VTileLoader​

comptime VTileLoader = SubTileLoaderLDS_st_8x32[config.dtype, config.kv_block, config.depth, Int(64) if fp8 else Int(32), Int(config.num_warps * 64)]

Here fp8 abbreviates the generated dtype test, which holds when the dtype code of config.dtype is one of 73 through 78 (the FP8 case in the descriptions on this page).

Methods​

load_q​

static def load_q[layout: TensorLayout](q_warp_2d: TileTensor[config.dtype, layout, Storage=q_warp_2d.Storage, address_space=q_warp_2d.address_space, linear_idx_type=q_warp_2d.linear_idx_type, element_size=q_warp_2d.element_size]) -> TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]

Loads the warp's Q sub-tile at d_qk from gmem into the row_l register tile. Mirrors MhaPrefillV2.load_q but iterates _NUM_Q_K_TILES = D_QK // MMA_K K-dim base tiles instead of DEPTH // MMA_K.

For FP8 at d_qk=192 (MMA_K=64) the loop runs 3 times: two iterations cover the nope segment (cols 0..127) and one covers the rope segment (cols 128..191). Q is stored in gmem as q_nope ∥ q_rope contiguously, so the underlying buffer_load_lds reads are uniform across the 3 tiles, with no nope/rope branching here.
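The tile walk described above can be enumerated with a small Python sketch. It assumes MMA_K divides d_nope, so no tile straddles the nope/rope boundary. The function name q_k_tiles is hypothetical, and the BF16 value MMA_K = 16 is inferred from the Q tile shape rather than stated for this method.

```python
def q_k_tiles(d_nope, d_rope, mma_k):
    """List (tile index, first col, last col, segment) for each K-dim base tile."""
    d_qk = d_nope + d_rope
    tiles = []
    for t in range(d_qk // mma_k):
        lo = t * mma_k
        hi = lo + mma_k - 1
        # A tile is in the nope segment when it ends before column d_nope.
        segment = "nope" if hi < d_nope else "rope"
        tiles.append((t, lo, hi, segment))
    return tiles


fp8_tiles = q_k_tiles(128, 64, 64)   # FP8: MMA_K = 64
bf16_tiles = q_k_tiles(128, 64, 16)  # BF16: assumed MMA_K = 16

for tile in fp8_tiles:
    print(tile)
print("BF16 tile count:", len(bf16_tiles))
```

The segment label is only for display. The load itself treats every tile the same way, since q_nope and q_rope are contiguous in gmem.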

Returns:

TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]