IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

MhaMmaOp

struct MhaMmaOp[T: DType, config: MhaConfigV2]

Namespace-style struct holding the shape constants, register-tile layouts, and SMEM→register loaders for MhaPrefillV2. All call sites go through static methods on this struct.

Supports two MFMA flavors on gfx950, comptime-selected on T:

  • BF16 → v_mfma_f32_32x32x16_bf16, MMA_K=16, 8 BF16 elts/lane.
  • FP8 e4m3 → v_mfma_scale_f32_32x32x64_f8f6f4, MMA_K=64, 32 FP8 elts/lane. MMA dispatch is automatic via gpu_mma SIMD-size overload resolution; only the per-lane fragment size differs.

load_K / load_V both comptime-branch on T.is_float8():

  • BF16 path: byte-identical to the original reference kernel — K via two-XOR swizzle + ds_read_b128, V via ds_read_tr16_b64_warp.
  • FP8 path: K reuses the same byte-level two-XOR swizzle (the sub-block is byte-equivalent at 32 rows × 64 B), V uses ds_read_tr8_b64 paired-lane transpose-loads matching TiledMmaLoader.load_v_fp8_strip.

Parameters

  • T (DType): Element data type (BF16 or FP8 e4m3).
  • config (MhaConfigV2): Shape configuration.

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

ATT_BF16_FULL_LAYOUT

comptime ATT_BF16_FULL_LAYOUT = row_major[(config // Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16)), (config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

Full att input-dtype tile pre-cast from FP32 (indexed by subtile_idx to feed mma_PV strip-by-strip).

BF16: 4 subtiles (KV_BLOCK=64 / MMA_K=16). FP8: 1 subtile (KV_BLOCK=64 / MMA_K=64).

ATT_BF16_SUB_LAYOUT

comptime ATT_BF16_SUB_LAYOUT = row_major[Int(1), (config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

One PV-A subtile (MMA_K-row strip of att, input-dtype).

BF16: 16-row strip, per-lane 8 BF16. FP8: 64-row strip, per-lane 32 FP8.

ATT_LAYOUT

comptime ATT_LAYOUT = row_major[(config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (Int((mul Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

Attention block (QK output, col_l rt_32x32 FP32).

DEPTH

comptime DEPTH = config.depth

FP8_MMA_K_128

comptime FP8_MMA_K_128 = config.fp8_mma_k_128

FRAG_ELTS

comptime FRAG_ELTS = (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))

Input elements per lane per MFMA base tile.

8 for BF16 (32×32×16). 32 for FP8 32×32×64. 32 for FP8 16×16×128 (the smaller M-dim is offset by the larger K-dim — total M×K÷64 is identical).

K_LAYOUT

comptime K_LAYOUT = row_major[(config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (config // Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16)), (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

K register tile (whole K, pre-loaded across cluster boundaries).

K_SUB_COLS

comptime K_SUB_COLS = Int(64) if T.is_float8() else Int(32)

K SMEM sub-block cols. 32 BF16 elts = 64 B/row; 64 FP8 elts = 64 B/row. Byte-equivalent — same SMEM geometry for both FP8 MMA shapes (32×32×64 and 16×16×128) because the parent allocation in mha_prefill_v2.mojo is unchanged; only the consumer-side lane partition differs.

For FP8 32×32×64: each sub-block holds exactly one 32×64 base tile.

For FP8 16×16×128: each sub-block holds two 16-row sub-strips (top + bottom 16 rows), each spanning 64 B. A single 16×128 MFMA base tile needs 2 of these K-direction sub-blocks (to cover 128 K elements at 64 cols/sub-block). The consumer-side lane partition is 16-lanes-per-row × 4 K-groups.

K_SUB_ROWS

comptime K_SUB_ROWS = 32

K SMEM sub-block rows. Same for BF16 and FP8 (32×32×64 path); the FP8 16×16×128 path keeps the same parent SMEM geometry so the cooperative DMA producer (SubTileLoaderLDS) doesn't change — the difference is purely on the consumer-side lane partition that load_K performs.

BF16: 32×32 BF16 elts = 32×64 B. FP8 (any shape): 32×64 FP8 elts = 32×64 B (byte-equivalent so the byte-positional swizzle reuses across dtypes).

KV_BLOCK

comptime KV_BLOCK = config.kv_block

MMA_K

comptime MMA_K = Int(128) if T.is_float8() and MhaMmaOp[T, config].FP8_MMA_K_128 else Int(64) if T.is_float8() else Int(16)

MMA_M

comptime MMA_M = Int(16) if T.is_float8() and MhaMmaOp[T, config].FP8_MMA_K_128 else Int(32)

MMA_N

comptime MMA_N = Int(16) if T.is_float8() and MhaMmaOp[T, config].FP8_MMA_K_128 else Int(32)

O_LAYOUT

comptime O_LAYOUT = row_major[(config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (Int((mul Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

Output accumulator (col_l rt_32x32 FP32).

O_T_LAYOUT

comptime O_T_LAYOUT = row_major[(config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (Int((mul Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

Output transpose (row_l view of the same storage as O_LAYOUT).

PV_A_FRAG_ELTS

comptime PV_A_FRAG_ELTS = (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))

Per-lane PV-A fragment width = MMA_K * MMA_N / 64.

8 for BF16 (1632/64). 32 for FP8 32×32×64 (6432/64). 32 for FP8 16×16×128 (128*16/64) — same as 32×32×64.

Folds to a literal Int at type-check time because MMA_K and MMA_N are both comptime constants.

Q_BLOCK_SIZE

comptime Q_BLOCK_SIZE = config.q_block_size

Q_LAYOUT

comptime Q_LAYOUT = row_major[(config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (config // Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16)), (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

Q register tile.

ROWL_HALF_LANES

comptime ROWL_HALF_LANES = Int(16) if (Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32) == Int(16)) else Int(32)

Lanes per row in the K MFMA A-side lane partition. Equal to MMA_M for the standard partition.

ROWL_STRIDE

comptime ROWL_STRIDE = (Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16) // Int(2)) if (Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32) == Int(32)) else (Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16) // Int(4))

Per-lane K-direction fragment width.

For 32×32×{16,64}: MMA_K // 2. Half-warp split packs 2 K-strips per base tile (8 for BF16 / 32 for FP8 32×32×64).

For 16×16×128: MMA_K // 4 = 32. 4 K-groups per base tile; each lane-group of 16 lanes owns one K-group of 32 FP8 elements.

V_LAYOUT

comptime V_LAYOUT = row_major[(config // Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16)), (config // Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32)), (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]()

V register tile.

V_SUB_COLS

comptime V_SUB_COLS = Int(64) if T.is_float8() else Int(32)

V SMEM sub-block cols. 32 BF16 elts = 64 B/row; 64 FP8 elts = 64 B/row. Byte-equivalent — same SMEM geometry for both FP8 MMA shapes.

V_SUB_ROWS

comptime V_SUB_ROWS = 8

V SMEM sub-block rows. Matches the BF16 reference st_8x32_s row count; FP8 V uses paired-lane ds_read_tr8_b64 (32×32×64 path) or a per-lane scalar gather (16×16×128 path) — both treat the V SMEM slab as a contiguous BN×depth block. The sub-block dims still drive the parent SMEM layout in mha_prefill_v2.mojo; keep V_SUB_ROWS=8 so the cooperative DMA producer geometry (SubTileLoaderLDS_st_8x32) is dtype/shape-agnostic.

Methods

load_K

static def load_K[layout_dst: TensorLayout, layout_src: TensorLayout, //](mut dst: TileTensor[T, layout_dst, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], src: TileTensor[T, layout_src, MutAnyOrigin, address_space=AddressSpace.SHARED])

Loads the whole (KV_BLOCK, DEPTH) K tile from SMEM into the row_l register tile (32×MMA_K base tiles), unswizzling on the way.

Caller must declare K SMEM with shape row_major[KV_BLOCK * (DEPTH / K_SUB_COLS), K_SUB_COLS] so the sub-block id linearizes via .tile[K_SUB_ROWS, K_SUB_COLS](id, 0).

BF16 path: each 32×32 sub-block holds two 32×16 base tiles (col parity 0 and 1). 8 BF16 elts/lane per base tile → one ds_read_b128. FP8 path: each 32×64 sub-block holds exactly one 32×64 base tile (the col-parity loop collapses). 32 FP8 elts/lane = 32 B = two 16-B _load_from_lds reads, joined.

load_V

static def load_V[layout_dst: TensorLayout, layout_src: TensorLayout, //](mut dst: TileTensor[T, layout_dst, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], src: TileTensor[T, layout_src, MutAnyOrigin, address_space=AddressSpace.SHARED])

Loads the whole V tile from SMEM into the col_l register tile.

BF16 path: ds_read_tr16_b64_warp over two 8×32 sub-blocks (top + bot) joined into an 8-elt fragment. V_LAYOUT = row_major[KV_BLOCK/16, DEPTH/32, 8].

FP8 path: paired-lane ds_read_tr8_b64 — 4 reads per (i, j) joined into a 32-elt fragment. V_LAYOUT = row_major[KV_BLOCK/64, DEPTH/32, 32]. Per-lane addressing mirrors TiledMmaLoader.load_v_fp8_strip, but the SMEM linearization uses the sub-tile-major layout the st_8x32 DMA writes — not a contiguous BN×depth slab. See the sub-tile linearization comment below.

Caller must declare V SMEM with shape row_major[KV_BLOCK * (DEPTH / V_SUB_COLS), V_SUB_COLS]. The DMA producer SubTileLoaderLDS_st_8x32 writes sub-tiles (V_SUB_ROWS × V_SUB_COLS) in row-major-by-sub-tile order: sub_id = sub_row * (DEPTH / V_SUB_COLS) + sub_col, where sub_row = key // V_SUB_ROWS and sub_col = depth // V_SUB_COLS. The FP8 base tile (MMA_K=64 keys, MMA_N=32 depth cols) spans 8 sub-tile rows (V_SUB_ROWS=8) in K and half a sub-tile col in N (V_SUB_COLS=64 vs MMA_N=32). Loader must compute the correct sub-tile id from the per-lane key — assuming contiguous depth-blocks catastrophically corrupts the second (and later) depth block at d >= V_SUB_COLS.

mma_QK

static def mma_QK[T_att: DType, layout_att: TensorLayout, layout_k: TensorLayout, layout_q: TensorLayout, //](mut att: TileTensor[T_att, layout_att, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut k: TileTensor[T, layout_k, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut q: TileTensor[T, layout_q, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

QK MFMA: att += k @ q^T. K is A (M-outer), Q is B (N-outer).

For each output base tile (n, m): att[n, m] += sum_k k[n, k] * q[m, k].

mma_PV

static def mma_PV[T_o: DType, layout_o: TensorLayout, layout_v: TensorLayout, layout_p: TensorLayout, //](mut o: TileTensor[T_o, layout_o, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut v: TileTensor[T, layout_v, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut p: TileTensor[T, layout_p, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

PV MFMA: o += v^T @ p. V is A (K-outer), P is B (K-outer, JIT-cast to Self.T from att_block).

For each output base tile (n, m): o[n, m] += sum_k v[k, n] * p[k, m].

exp2_inplace_range

static def exp2_inplace_range[T_att: DType, layout: TensorLayout, //, start: Int, end: Int](mut tile: TileTensor[T_att, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

In-place exp2 over a base-tile-aligned per-lane slice tile[start:end]. start and end must be multiples of the fragment width so the slice maps to whole base tiles. Used to split first / second half of the softmax exp2 across PV MFMAs.

T_att is FP32 in the BF16 attention path and BF16 in the FP8 attention path's BF16 softmax. math_exp2 works with either; for BF16 it lowers to v_cvt_f32_bf16 + v_exp_f32 + v_cvt_pkrtz_bf16_f32 (no packed transcendental on gfx950 — v_exp_f32 is scalar regardless of input dtype).