Mojo struct
MlaPrefillV2Core
struct MlaPrefillV2Core[config: MlaConfigV2]
8-warp MLA forward kernel parameterized by MlaConfigV2.
Sibling of MhaPrefillV2. Each block runs config.num_warps wave64 warps that share K_nope / K_rope / V SMEM via cooperative DMA. Warp w owns Q rows [w * q_block_size, (w + 1) * q_block_size) of the block's stripe and carries its own register-resident attention state.
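The warp-to-row ownership described above is plain index arithmetic. The following Python sketch models it under two assumptions: warps are wave64 (64 lanes), and thread IDs are assigned to warps in consecutive groups of 64. The helper names and the example values (8 warps, q_block_size = 32) are hypothetical; they are not MlaConfigV2 defaults.

```python
WAVE_SIZE = 64  # lanes per wave64 warp


def num_threads(num_warps: int) -> int:
    """Threads per block: num_warps wave64 warps."""
    return num_warps * WAVE_SIZE


def warp_of_thread(tid: int) -> int:
    """Warp index of a thread, assuming consecutive groups of 64 lanes."""
    return tid // WAVE_SIZE


def warp_q_rows(w: int, q_block_size: int) -> range:
    """Half-open Q row range [w * q_block_size, (w + 1) * q_block_size)."""
    return range(w * q_block_size, (w + 1) * q_block_size)


num_warps, q_block_size = 8, 32  # hypothetical example values
print(num_threads(num_warps))  # 512
rows = warp_q_rows(warp_of_thread(130), q_block_size)
print(rows.start, rows.stop)  # 64 96
```

Because each warp's range is disjoint, the per-warp ranges tile the block's stripe of num_warps * q_block_size rows. Each warp can therefore keep its attention state in registers without coordinating with the other warps.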
Parameters
- config (MlaConfigV2): Shape configuration.
Implemented traits
comptime members
CACHE_DEPTH
comptime CACHE_DEPTH = config.cache_depth
Latent K cache row width. For DeepSeek-V3 MLA: 576. The columns between D_NOPE (128) and ROPE_CACHE_OFFSET (512) are reserved and unused, but they still count toward the cache stride. Matches the existing BF16 MLA path at mla_prefill.mojo:54.
D_NOPE
comptime D_NOPE = config.depth
Alias for DEPTH under the nope-segment name. D_NOPE == DEPTH == d_pv, because DeepSeek-V3 MLA does not apply RoPE to V.
D_QK
comptime D_QK = config.d_qk
Q / K depth (d_nope + d_rope). For DeepSeek-V3 MLA: 192.
D_ROPE
comptime D_ROPE = config.d_rope
RoPE-applied segment depth on Q and K. For DeepSeek-V3 MLA: 64.
DEPTH
comptime DEPTH = config.depth
V / O head depth (d_pv). For DeepSeek-V3 MLA: 128. Same semantics as MhaPrefillV2.DEPTH: MhaMmaOp is specialized on this value via config.mha(), so the V and PV machinery is shared verbatim.
K_NOPE_LAYOUT
comptime K_NOPE_LAYOUT = MlaMmaOp[config.dtype, config.mha()].K_LAYOUT
K_nope register tile, identical to the MHA path's K register tile at d=d_nope. Re-exposed from _MmaOp.K_LAYOUT.
K_ROPE_LAYOUT
comptime K_ROPE_LAYOUT = row_major[config // MMA_M, config // MMA_K, (MMA_K * MMA_M) // 64]()
MMA_M and MMA_K abbreviate the dtype-dependent conditionals in the generated signature. The signature tests the dtype against codes 73 to 78, which are the FP8 dtypes (config.dtype.is_float8()):
MMA_M = 16 if (config.dtype.is_float8() and config.mha().fp8_mma_k_128) else 32
MMA_K = 128 if (config.dtype.is_float8() and config.mha().fp8_mma_k_128) else 64 if config.dtype.is_float8() else 16
The generated signature prints the config field that is divided in the first two dimensions only as config.
K_rope register tile at d_rope, loaded verbatim by MhaMmaOp.load_K. For d_rope = 64, the column loop has width 1 for FP8 (64 // 64) and 4 for BF16 (64 // 16). SMEM source: smem_layout_k_rope.
k_swizzle
comptime k_swizzle = Optional(Swizzle(Int(1), Int(0), Int(4)))
k_swizzle2
comptime k_swizzle2 = Optional(Swizzle(Int(1), Int(1), Int(4)))
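k_swizzle and k_swizzle2 have no description here. The Python sketch below shows how they would act on an offset under one assumption: that Swizzle(bits, base, shift) follows the CuTe-style convention, where the `bits`-wide field at bit position base + shift is XORed into the field at position base. This is an illustrative model, not the Mojo implementation. The units the loader applies the offset in (elements or vector chunks) are not stated on this page, and the helper names are hypothetical.

```python
def swizzle(offset: int, bits: int, base: int, shift: int) -> int:
    """Model of a (bits, base, shift) swizzle: XOR the `bits`-wide field at
    bit position base + shift into the field at bit position base."""
    mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


def k_swizzle(offset: int) -> int:
    # Swizzle(1, 0, 4): bit 4 is XORed into bit 0.
    return swizzle(offset, 1, 0, 4)


def k_swizzle2(offset: int) -> int:
    # Swizzle(1, 1, 4): bit 5 is XORed into bit 1.
    return swizzle(offset, 1, 1, 4)


print([k_swizzle(o) for o in (0, 5, 16, 48)])   # [0, 5, 17, 49]
print([k_swizzle2(o) for o in (0, 5, 32, 48)])  # [0, 5, 34, 50]
```

An XOR swizzle of this form is its own inverse and permutes offsets within each aligned group, so it can redistribute LDS accesses across banks without changing the set of addresses touched.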
KTileLoader
comptime KTileLoader = SubTileLoaderLDS[config.dtype, MlaPrefillV2Core[config].k_swizzle, MlaPrefillV2Core[config].k_swizzle2]
KV_BLOCK
comptime KV_BLOCK = config.kv_block
NUM_HEADS
comptime NUM_HEADS = config.num_heads
NUM_KV_HEADS
comptime NUM_KV_HEADS = config.num_kv_heads
NUM_THREADS
comptime NUM_THREADS = (config * Int(64))
Threads per block: 64 lanes for each of the NUM_WARPS wave64 warps.
NUM_WARPS
comptime NUM_WARPS = config.num_warps
prescale_q
comptime prescale_q = not config.dtype.is_float8().__bool__()
Whether Q is prescaled at load time. True for BF16 (a single FP32 multiply per Q element during the load). False for FP8, which avoids the precision loss of converting the scaled FP32 value back to FP8; the scale is instead applied to att_block after the QK product.
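In exact arithmetic both choices give the same score, since (scale · q) · k = scale · (q · k); they differ only in where rounding happens. The Python sketch below illustrates the extra rounding that prescaling an FP8 Q would add. It uses a crude FP8-like model with 3 explicit mantissa bits (as in E4M3) that ignores exponent range and subnormals. The helper names and the scale values are hypothetical.

```python
import math


def round_to_mantissa(x: float, mantissa_bits: int = 3) -> float:
    """Round x to 1 + mantissa_bits significant bits (crude E4M3-like model:
    no exponent clamping, no subnormals, Python's round-half-to-even)."""
    if x == 0.0:
        return 0.0
    m, e = math.frexp(x)  # x = m * 2**e with 0.5 <= |m| < 1
    steps = 1 << (mantissa_bits + 1)
    return math.ldexp(round(m * steps) / steps, e)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def score_prescaled(q, k, scale):
    # Scale Q in high precision, then round back to the FP8-like grid
    # before the QK product (the path prescale_q = False avoids for FP8).
    return dot([round_to_mantissa(x * scale) for x in q], k)


def score_postscaled(q, k, scale):
    # Use Q as loaded (already FP8-representable) and scale the QK result.
    return dot(q, k) * scale


print(score_prescaled([3.0], [1.0], 0.1))   # 0.3125 (0.3 snapped to the grid)
print(score_postscaled([3.0], [1.0], 0.1))  # 0.3 up to FP64 rounding
```

With a power-of-two scale such as 0.125 the two paths agree exactly, because scaling only shifts the exponent. The discrepancy comes from re-rounding a non-power-of-two product.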
Q_BLOCK_SIZE
comptime Q_BLOCK_SIZE = config.q_block_size
Q_LAYOUT_MLA
comptime Q_LAYOUT_MLA = row_major[config // MMA_M, config // MMA_K, (MMA_K * MMA_M) // 64]()
MMA_M and MMA_K abbreviate the dtype-dependent conditionals in the generated signature. The signature tests the dtype against codes 73 to 78, which are the FP8 dtypes (config.dtype.is_float8()):
MMA_M = 16 if (config.dtype.is_float8() and config.mha().fp8_mma_k_128) else 32
MMA_K = 128 if (config.dtype.is_float8() and config.mha().fp8_mma_k_128) else 64 if config.dtype.is_float8() else 16
The generated signature prints the config field that is divided in the first two dimensions only as config.
Q register tile at d_qk. For FP8 (MMA_K = 64) with KV_BLOCK=64 and d_qk=192, the shape is [1, 3, 32]: 1 row base tile × 3 column base tiles × 32 FP8 elements per lane. For BF16 the shape is [1, 12, 8].
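The shapes quoted for Q_LAYOUT_MLA and K_ROPE_LAYOUT follow from the MMA_M / MMA_K selection: depth // MMA_K column tiles, and MMA_K * MMA_M // 64 elements per lane of a wave64 warp. The Python sketch below reproduces those numbers. The row count of 32 is an assumption chosen because it yields the leading 1 in the documented shapes; the page does not show which config field the first dimension divides. The function names are hypothetical.

```python
def mma_shape(is_fp8: bool, fp8_mma_k_128: bool = False):
    """(MMA_M, MMA_K) as selected by the dtype conditionals in the layouts."""
    if is_fp8 and fp8_mma_k_128:
        return 16, 128
    if is_fp8:
        return 32, 64
    return 32, 16


def reg_tile_shape(rows: int, depth: int, is_fp8: bool, fp8_mma_k_128: bool = False):
    """[rows // MMA_M, depth // MMA_K, MMA_K * MMA_M // 64] for a wave64 warp."""
    mma_m, mma_k = mma_shape(is_fp8, fp8_mma_k_128)
    return [rows // mma_m, depth // mma_k, (mma_k * mma_m) // 64]


ROWS = 32  # assumption: reproduces the leading 1 of the documented shapes
print(reg_tile_shape(ROWS, 192, is_fp8=True))       # Q, FP8:  [1, 3, 32]
print(reg_tile_shape(ROWS, 192, is_fp8=False))      # Q, BF16: [1, 12, 8]
print(reg_tile_shape(ROWS, 64, is_fp8=True)[1:])    # K_rope, FP8:  [1, 32]
print(reg_tile_shape(ROWS, 64, is_fp8=False)[1:])   # K_rope, BF16: [4, 8]
```

The middle entry of the K_rope shape is the column-loop width mentioned for K_ROPE_LAYOUT: 1 for FP8 and 4 for BF16.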
ROPE_CACHE_OFFSET
comptime ROPE_CACHE_OFFSET = config.rope_cache_offset
Column offset of k_rope within the latent cache row. For DeepSeek-V3 MLA: 512 (k_nope at [:, :128], gap, k_rope at [:, 512:576]).
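The latent cache row layout described by CACHE_DEPTH and ROPE_CACHE_OFFSET can be checked with a few ranges. The Python sketch below uses the DeepSeek-V3 values from this page. It assumes the cache is stored row-major with rows contiguous at stride CACHE_DEPTH; that is a simplification, since the page does not describe paging or other indexing. The function names are hypothetical.

```python
# DeepSeek-V3 MLA values from the member descriptions on this page.
D_NOPE = 128
D_ROPE = 64
ROPE_CACHE_OFFSET = 512
CACHE_DEPTH = 576


def cache_segments():
    """Column ranges of one latent cache row."""
    nope = range(0, D_NOPE)                                      # k_nope
    gap = range(D_NOPE, ROPE_CACHE_OFFSET)                       # reserved / unused
    rope = range(ROPE_CACHE_OFFSET, ROPE_CACHE_OFFSET + D_ROPE)  # k_rope
    return nope, gap, rope


def flat_index(row: int, col: int) -> int:
    """Element index, assuming contiguous rows with stride CACHE_DEPTH."""
    return row * CACHE_DEPTH + col


nope, gap, rope = cache_segments()
print(len(nope), len(gap), len(rope))   # 128 384 64
print(flat_index(1, ROPE_CACHE_OFFSET))  # 1088
```

The 384 gap columns are never read, but they are part of every row's stride. That is why the row width is 576 rather than 192.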
VTileLoader
comptime VTileLoader = SubTileLoaderLDS_st_8x32[config.dtype, config.kv_block, config.depth, Int(64) if config.dtype.is_float8() else Int(32), Int(config.num_warps * 64)]
Methods
load_q
static def load_q[layout: TensorLayout](q_warp_2d: TileTensor[config.dtype, layout, Storage=q_warp_2d.Storage, address_space=q_warp_2d.address_space, linear_idx_type=q_warp_2d.linear_idx_type, element_size=q_warp_2d.element_size]) -> TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]
Loads the warp's Q sub-tile at d_qk from gmem into the row_l register tile. Mirrors MhaPrefillV2.load_q, but iterates over _NUM_Q_K_TILES = D_QK // MMA_K K-dim base tiles instead of DEPTH // MMA_K.
For FP8 at d_qk=192 (MMA_K=64), the loop runs 3 times: two iterations cover the nope segment (columns 0..127) and one covers the rope segment (columns 128..191). Q is stored contiguously in gmem as q_nope ‖ q_rope, so the underlying buffer_load_lds reads are uniform across the 3 tiles and need no nope/rope branching.
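The tile loop in load_q can be modeled in Python at the level of column indices. The sketch below lists which columns each of the D_QK // MMA_K iterations covers, and shows that every tile is the same kind of contiguous slice of q_nope ‖ q_rope. The segment labels are for illustration only; per the description, the load itself does not branch on them. It does not model buffer_load_lds, and the function names are hypothetical.

```python
def q_k_tiles(d_nope: int, d_rope: int, mma_k: int):
    """(tile, first_col, last_col, segment) for each of the D_QK // MMA_K tiles."""
    d_qk = d_nope + d_rope
    tiles = []
    for t in range(d_qk // mma_k):  # _NUM_Q_K_TILES iterations
        lo, hi = t * mma_k, (t + 1) * mma_k - 1
        if hi < d_nope:
            segment = "nope"
        elif lo >= d_nope:
            segment = "rope"
        else:
            segment = "mixed"  # tile straddles the nope/rope boundary
        tiles.append((t, lo, hi, segment))
    return tiles


def tile_slices(q_nope, q_rope, mma_k: int):
    """Each tile is the same contiguous slice of the row q_nope ‖ q_rope."""
    q_row = list(q_nope) + list(q_rope)
    return [q_row[t * mma_k:(t + 1) * mma_k] for t in range(len(q_row) // mma_k)]


print(q_k_tiles(128, 64, 64))
# [(0, 0, 63, 'nope'), (1, 64, 127, 'nope'), (2, 128, 191, 'rope')]
print(len(q_k_tiles(128, 64, 16)))  # BF16 (MMA_K = 16): 12
```

Because tile t always starts at column t * MMA_K of one contiguous row, the same load code serves both segments. This is the property the description relies on to avoid nope/rope branching.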
Returns:
TileTensor[config.dtype, Layout[*?, *?], MutUntrackedOrigin, address_space=AddressSpace.LOCAL]