Mojo struct

MlaMmaOp

struct MlaMmaOp[T: DType, config: MhaConfigV2]

MLA MMA op for MlaPrefillV2 / MlaPrefillV2Core.

Extends MhaMmaOp: it re-exports MhaMmaOp's shape constants and register-tile layouts, delegates _swizzle_K_sub / mma_QK / mma_PV / exp2_inplace_range to it, and adds the MLA FP8 32x32x64 fragment loaders (load_K_frag, precompute_v_lane_base, load_V_from_lane_base, load_V_frag) that stream K/V through a rotating in-register band (the reference never materializes the whole V tile).

The split isolates these FP8 extensions from MHA: the MhaPrefillV2 BF16 prefill path uses MhaMmaOp directly, so an edit to MLA's FP8 fragment loaders can no longer perturb the MHA kernel's BF16 load_K / load_V schedule (interleaving FP8 branches into the shared whole-tile loaders opens a K/V DMA-vs-ds_read vmcnt race in MhaPrefillV2).

MLA never calls the whole-tile load_K / load_V; it consumes K/V fragment-at-a-time via the *_frag loaders. The whole-tile loaders therefore stay on MhaMmaOp only.

Parameters​

  • ​T (DType): Element data type (FP8 e4m3 for the MLA prefill path).
  • ​config (MhaConfigV2): Shape configuration (MhaConfigV2, via MlaConfigV2.mha()).

Implemented traits​

AnyType, ImplicitlyDeletable

comptime members​

ATT_BF16_FULL_LAYOUT​

comptime ATT_BF16_FULL_LAYOUT = MhaMmaOp[T, config].ATT_BF16_FULL_LAYOUT

ATT_BF16_SUB_LAYOUT​

comptime ATT_BF16_SUB_LAYOUT = MhaMmaOp[T, config].ATT_BF16_SUB_LAYOUT

ATT_LAYOUT​

comptime ATT_LAYOUT = MhaMmaOp[T, config].ATT_LAYOUT

DEPTH​

comptime DEPTH = MhaMmaOp[T, config].DEPTH

FP8_MMA_K_128​

comptime FP8_MMA_K_128 = MhaMmaOp[T, config].FP8_MMA_K_128

FRAG_ELTS​

comptime FRAG_ELTS = MhaMmaOp[T, config].FRAG_ELTS

K_LAYOUT​

comptime K_LAYOUT = MhaMmaOp[T, config].K_LAYOUT

K_SUB_COLS​

comptime K_SUB_COLS = MhaMmaOp[T, config].K_SUB_COLS

K_SUB_ROWS​

comptime K_SUB_ROWS = MhaMmaOp[T, config].K_SUB_ROWS

KV_BLOCK​

comptime KV_BLOCK = MhaMmaOp[T, config].KV_BLOCK

MMA_K​

comptime MMA_K = MhaMmaOp[T, config].MMA_K

MMA_M​

comptime MMA_M = MhaMmaOp[T, config].MMA_M

MMA_N​

comptime MMA_N = MhaMmaOp[T, config].MMA_N

O_LAYOUT​

comptime O_LAYOUT = MhaMmaOp[T, config].O_LAYOUT

O_T_LAYOUT​

comptime O_T_LAYOUT = MhaMmaOp[T, config].O_T_LAYOUT

PV_A_FRAG_ELTS​

comptime PV_A_FRAG_ELTS = MhaMmaOp[T, config].PV_A_FRAG_ELTS

Q_BLOCK_SIZE​

comptime Q_BLOCK_SIZE = MhaMmaOp[T, config].Q_BLOCK_SIZE

Q_LAYOUT​

comptime Q_LAYOUT = MhaMmaOp[T, config].Q_LAYOUT

ROWL_HALF_LANES​

comptime ROWL_HALF_LANES = MhaMmaOp[T, config].ROWL_HALF_LANES

ROWL_STRIDE​

comptime ROWL_STRIDE = MhaMmaOp[T, config].ROWL_STRIDE

V_LAYOUT​

comptime V_LAYOUT = MhaMmaOp[T, config].V_LAYOUT

V_SUB_COLS​

comptime V_SUB_COLS = MhaMmaOp[T, config].V_SUB_COLS

V_SUB_ROWS​

comptime V_SUB_ROWS = MhaMmaOp[T, config].V_SUB_ROWS

Methods​

mma_QK​

static def mma_QK[T_att: DType, layout_att: TensorLayout, layout_k: TensorLayout, layout_q: TensorLayout, //](mut att: TileTensor[T_att, layout_att, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut k: TileTensor[T, layout_k, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut q: TileTensor[T, layout_q, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

QK MFMA: delegates to MhaMmaOp.mma_QK (body identical).

mma_PV​

static def mma_PV[T_o: DType, layout_o: TensorLayout, layout_v: TensorLayout, layout_p: TensorLayout, //](mut o: TileTensor[T_o, layout_o, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut v: TileTensor[T, layout_v, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], mut p: TileTensor[T, layout_p, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

PV MFMA: delegates to MhaMmaOp.mma_PV (body identical).

exp2_inplace_range​

static def exp2_inplace_range[T_att: DType, layout: TensorLayout, //, start: Int, end: Int](mut tile: TileTensor[T_att, layout, MutUntrackedOrigin, address_space=AddressSpace.LOCAL])

exp2 over a slice: delegates to MhaMmaOp.exp2_inplace_range.
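As a rough illustration of what "exp2 over a slice" means, the sketch below applies 2**x in place to a range of a flat Python list. It assumes start and end form a half-open range [start, end) over flattened elements; the source does not state the slicing convention, and the real method operates on a register TileTensor rather than a list.

```python
def exp2_inplace_range(tile, start, end):
    """Replace tile[i] with 2**tile[i] for start <= i < end.

    Assumption: half-open element range; elements outside it are untouched.
    """
    for i in range(start, end):
        tile[i] = 2.0 ** tile[i]


scores = [0.0, 1.0, 2.0, 3.0]
exp2_inplace_range(scores, 1, 3)  # only elements 1 and 2 are exponentiated
```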

load_K_frag​

static def load_K_frag[sub_id: Int](src: TileTensor[T, MutAnyOrigin, address_space=AddressSpace.SHARED]) -> SIMD[T, (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> 
T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]

Loads ONE K MFMA fragment (sub_id) from the K SMEM sub-view src and returns it as a SIMD value. This is the single-fragment factoring of the FP8 32x32x64 K-load inner loop.

Returning a SIMD value (rather than writing a register tile) lets the caller stream fragments through a rotating in-register band where each slot is a plain SSA value. This sidesteps the strided-register-sub-view write that lands on the wrong VGPRs when a sub-tile of a larger reg_alloc is the destination at a non-zero offset. The fragment feeds gpu_mma directly (the QK A-operand).
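The "rotating in-register band" can be pictured as a small fixed set of slots that fragments pass through in order. The sketch below is a host-side Python model only: the modular slot choice, the refill-after-consume order, and the names stream_through_band, load_frag, and consume are all assumptions made for illustration, not the kernel's actual schedule.

```python
def stream_through_band(load_frag, consume, num_frags, band_slots):
    """Stream num_frags fragments through a band of band_slots slots.

    Each slot holds one fragment value (the analogue of a plain SSA value),
    so at most band_slots fragments are live at any time.
    """
    band = [None] * band_slots
    # Prologue: fill the band with the first fragments.
    for i in range(min(band_slots, num_frags)):
        band[i] = load_frag(i)
    # Steady state: consume a slot, then refill it with a later fragment.
    for i in range(num_frags):
        slot = i % band_slots
        consume(i, band[slot])
        nxt = i + band_slots
        if nxt < num_frags:
            band[slot] = load_frag(nxt)


consumed = []
stream_through_band(
    load_frag=lambda i: f"frag{i}",
    consume=lambda i, frag: consumed.append(frag),
    num_frags=5,
    band_slots=2,
)
```

In this model the whole tile is never held at once, which is the property the text attributes to the fragment-at-a-time loaders.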

sub_id is the K SMEM sub-block index; it enters only as a comptime ds_read offset: immediate. The two per-lane swizzled bases are loop-invariant in src.ptr, so across a caller's unrolled per-fragment stream LLVM CSEs them to a single base register pair (the reference's v226 single-base K pattern).

Constraints:

FP8 32x32x64 path only (MMA_K == 64).

Returns:

SIMD[T, N]: the K fragment, where N is the comptime width expression spelled out in the signature. On the FP8 32x32x64 path (MMA_K == 64) it evaluates to (64 * 32) // 64 = 32 elements.
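The SIMD width in the signature is printed as a raw comptime expression, so a small Python sketch of the arithmetic it appears to encode may help. It rests on two assumptions: that the repeated dtype-code test (codes 73 through 78) means "T is an FP8 type", and that this test is and-ed with config.fp8_mma_k_128 in the branches where both appear. The function names are hypothetical.

```python
def mma_k(is_fp8, fp8_mma_k_128):
    # First factor of the printed product: 128, 64, or 16.
    if is_fp8 and fp8_mma_k_128:
        return 128
    return 64 if is_fp8 else 16


def mma_mn(is_fp8, fp8_mma_k_128):
    # Second factor of the printed product: 16 or 32.
    return 16 if (is_fp8 and fp8_mma_k_128) else 32


def frag_width(is_fp8, fp8_mma_k_128):
    # Mirrors "(mul <first>, <second>) // 64" from the signature.
    return (mma_k(is_fp8, fp8_mma_k_128) * mma_mn(is_fp8, fp8_mma_k_128)) // 64


# FP8 32x32x64 path (the only one load_K_frag accepts): 64 * 32 // 64
fp8_width = frag_width(is_fp8=True, fp8_mma_k_128=False)
```

Under these assumptions both FP8 configurations yield a 32-element fragment, while a non-FP8 type would yield 8.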

precompute_v_lane_base​

static def precompute_v_lane_base[origin: Origin[mut=origin.mut], //, v_full_v227: Bool = False, v227_layout: Bool = False](v_slot_ptr: UnsafePointer[Scalar[T], origin, address_space=AddressSpace.SHARED]) -> UnsafePointer[Scalar[T], origin, address_space=AddressSpace.SHARED]

Computes the per-lane V LDS base pointer for the FP8 32x32x64 path ONCE per V SMEM slot (caller passes v_smem_<stage>.ptr). Hoisting this base out of the per-fragment readout collapses the per-call lane offset across all invocations from the same slot into ONE shared materialization, mirroring the reference's v227 base carried across the entire main loop. The caller threads the returned pointer into load_V_frag. Origin is propagated from the input slot so the returned pointer carries the slot's lifetime/mutability annotation.

Parameters:

  • v_full_v227 (Bool): Reference v227 V adapter per-lane READ base. Default False → byte-identical. When True, replaces the per-lane base ENTIRELY with the reference's exact v227 formula (the WHOLE read map), a per-tr8-cycle bank-quadrant bijection that eliminates the V-transpose LDS bank conflict. This is HALF of the adapter R: the caller MUST also pass v_full_v227=True to load_V_frag (the faithful readout cell) AND run the matching producer (SubTileLoaderLDS_st_8x32[v_full_v227=True] / MlaPrefillV2Core._dma_v[v_full_v227=True], the W); the three compose to the standard PV fragment (numerically equivalent). FP8 32x32x64 only. The default-on V LDS adapter for MlaPrefillV2; production MLA passes v_full_v227=False. Driven by -D v_full_v227.
  • v227_layout (Bool): Spells the v_full_v227 per-lane READ base via CuTe Layout Algebra (crd2idx over a per-bit Coord) instead of the hand-rolled runtime bit arithmetic; only consulted when v_full_v227 is True. SAME mapping, different spelling: the v227 base is bit-LINEAR over the bit-decomposed lane, so it is crd2idx(lane, Coord(2,2,2,2,2,2), Coord(8, 0x80, 0x820, 0x1040, 16, 0x410)) (the 2-bit field ((v0>>2)&3)*0x820 splits to bit2*0x820 + bit3*0x1040). Mirrors the WRITE side (SubTileLoaderLDS_st_8x32[_v227_layout]); this is the no-underscore parameter form of the WRITE side's struct field _v227_layout, both driven by -D v227_layout. Numerically equivalent.
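To make the "bit-linear" claim concrete, here is a minimal Python model of crd2idx over the per-bit shape and strides quoted above. The shape and stride tuples come from the text; the mapping of mode i to lane bit i is inferred from the bit2/bit3 split it describes, and hand_rolled is a hypothetical reconstruction of the equivalent bit arithmetic, not the kernel source.

```python
SHAPE = (2, 2, 2, 2, 2, 2)
STRIDE = (8, 0x80, 0x820, 0x1040, 16, 0x410)


def crd2idx(lane, shape=SHAPE, stride=STRIDE):
    """Decompose lane into per-mode coordinates (mode 0 fastest) and dot with stride."""
    idx = 0
    for extent, s in zip(shape, stride):
        idx += (lane % extent) * s
        lane //= extent
    return idx


def hand_rolled(lane):
    """Same mapping as explicit bit arithmetic (reconstructed from the strides)."""
    return (
        (lane & 1) * 8
        + ((lane >> 1) & 1) * 0x80
        + ((lane >> 2) & 3) * 0x820  # 2-bit field: bit2*0x820 + bit3*0x1040
        + ((lane >> 4) & 1) * 16
        + ((lane >> 5) & 1) * 0x410
    )


# The two spellings agree on every lane of a 64-lane wavefront.
assert all(crd2idx(lane) == hand_rolled(lane) for lane in range(64))
```

Because 2 * 0x820 == 0x1040, the 2-bit field multiplies out to exactly the two per-bit strides, which is why the layout form and the bit-arithmetic form are the same map.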

Returns:

UnsafePointer[Scalar[T], origin, address_space=AddressSpace.SHARED]

load_V_from_lane_base​

static def load_V_from_lane_base[layout_dst: TensorLayout](mut dst: TileTensor[T, layout_dst, MutUntrackedOrigin, address_space=AddressSpace.LOCAL], v_lane_base: UnsafePointer[Scalar[T], address_space=AddressSpace.SHARED])

FP8 32x32x64 V load consuming a pre-computed per-lane base pointer (hoisted by the caller). Each call site adds ONLY comptime per-cell offsets to v_lane_base: no per-call per-lane base recomputation, no src.ptr indirection.

Collapses the 27 distinct V ds_read base VGPRs at KV=128 (one per V-load × slot toggle) toward the reference's ~3-base pattern (one per SMEM slot, carried across all calls from that slot). Codegen-equivalent to recomputing the base inside each call on the K=128 fast path (the AMDGPU backend recovers the same ds_read offset:imm), but it makes the kernel-level hoist explicit at the source level.

load_V_frag​

static def load_V_frag[i_strip: Int, j_depth: Int, v_full_v227: Bool = False](v_lane_base: UnsafePointer[Scalar[T], address_space=AddressSpace.SHARED]) -> SIMD[T, (Int((mul Int(128) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(64) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 78) else Int(16), Int(16) if (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) if (eq 
#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 73) else (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 74) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 75) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 76) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 77) or (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<_std::_builtin::_dtype::_DType> T, "_mlir_value">>, 78) and config.fp8_mma_k_128 else Int(32))) // Int(64))]

Loads ONE V MFMA fragment (i_strip, j_depth) from the pre-computed per-lane V LDS base and returns it as a SIMD value. This is the single-fragment factoring of load_V_from_lane_base's FP8 32x32x64 inner (i, j) cell (the 4 paired-lane ds_read_tr8_b64 joined into one SIMD[FP8, 32]).

Returning a SIMD value (rather than writing a register tile) is the V counterpart of load_K_frag: it lets the caller stream V fragments through the SAME rotating in-register band K streamed through (K and V have disjoint lifetimes: K is consumed in the QK MFMAs before softmax, V in the PV MFMAs after, so the band's registers are free for V by PV time). Each slot is then a plain SSA value, sidestepping the strided-register-sub-view write that lands on the wrong VGPRs at a non-zero offset, and avoids materializing the whole 64-VGPR V_LAYOUT tile held live across the softmax/exp/FP8 clusters (the reference never materializes V; it transpose-reads V fragment-at-a-time through its reused v[28:59] band). The fragment feeds mma_PV / gpu_mma directly (the PV A-operand).

i_strip is the K-direction base tile (i_strip*MMA_K keys per strip; 0 .. KV_BLOCK/MMA_K); j_depth is the depth tile (j_depth*MMA_N cols; 0 .. DEPTH/MMA_N). Both enter only as comptime ds_read_tr8_b64 offset: immediates, so across a caller's unrolled per-fragment stream LLVM CSEs v_lane_base to a single base-register set (the reference's v227 single-base V pattern), the same hoist load_V_from_lane_base already relies on.
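The (i_strip, j_depth) index space can be enumerated with a few lines of Python. This is a sketch under stated assumptions: MMA_K = 64 and MMA_N = 32 are read off the "32x32x64" shape name, KV_BLOCK = 128 matches the KV=128 case mentioned on this page, DEPTH = 128 is a purely hypothetical value, and both index ranges are taken as exclusive at the upper end.

```python
from itertools import product

MMA_K = 64      # keys per strip on the FP8 32x32x64 path
MMA_N = 32      # depth columns per tile (assumed from the shape name)


def v_fragment_grid(kv_block, depth, mma_k=MMA_K, mma_n=MMA_N):
    """List every (i_strip, j_depth) fragment with the key/column span it covers."""
    n_strips = kv_block // mma_k
    n_depth = depth // mma_n
    return [
        {
            "i_strip": i,
            "j_depth": j,
            "keys": (i * mma_k, (i + 1) * mma_k),   # half-open key range
            "cols": (j * mma_n, (j + 1) * mma_n),   # half-open column range
        }
        for i, j in product(range(n_strips), range(n_depth))
    ]


grid = v_fragment_grid(kv_block=128, depth=128)  # DEPTH=128 is hypothetical
```

With these example sizes the caller would issue 2 * 4 = 8 load_V_frag calls per V tile, each with its own comptime (i_strip, j_depth) pair.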

Constraints:

FP8 32x32x64 path only (MMA_K == 64).

Parameters:

  • ​i_strip (Int): K-direction base tile index.
  • ​j_depth (Int): Depth tile index.
  • v_full_v227 (Bool): Reference v227 V adapter READ cell. Default False → OUR st_8x32 cell decode (byte-identical). When True, uses the faithful reference readout cell offset i_strip*0x2080 + j_depth*0x20 + r*0x100 (r = subread 0..3) on top of the v227 per-lane base (precompute_v_lane_base[v_full_v227=True]). This is the R of the adapter W∘R pair; the producer (SubTileLoaderLDS_st_8x32[v_full_v227=True] / MlaPrefillV2Core._dma_v[v_full_v227=True]) MUST set v_full_v227=True too, or V scrambles. The two compose to the standard PV fragment, bank-conflict-free (v227's per-tr8-cycle bank-quadrant bijection). FP8 32x32x64 only.
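The readout cell offset quoted for v_full_v227=True is plain affine arithmetic, sketched below in Python. The three strides and the four sub-reads come from the formula in the text; the function name is hypothetical, and treating the result as an offset added to the per-lane base (rather than, say, an element index in some other unit) is an assumption.

```python
STRIP_STRIDE = 0x2080    # per i_strip
DEPTH_STRIDE = 0x20      # per j_depth
SUBREAD_STRIDE = 0x100   # per sub-read r
NUM_SUBREADS = 4         # r = 0..3


def v227_cell_offsets(i_strip, j_depth):
    """Offsets of the four sub-reads of one (i_strip, j_depth) fragment."""
    cell = i_strip * STRIP_STRIDE + j_depth * DEPTH_STRIDE
    return [cell + r * SUBREAD_STRIDE for r in range(NUM_SUBREADS)]


# Every value depends only on compile-time indices, so each one can be
# folded into an instruction immediate on top of a single per-lane base.
offsets = v227_cell_offsets(i_strip=1, j_depth=2)
```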

Returns:

SIMD[T, N]: the V fragment, where N is the comptime width expression spelled out in the signature. On the FP8 32x32x64 path (MMA_K == 64) it evaluates to (64 * 32) // 64 = 32 elements.