IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo module

mha_mma_op

MHA MMA operator: shape constants, SMEM→register loaders, and MFMA dispatch used by MhaPrefillV2.

Supports two MFMA flavors on gfx950. The flavor is selected at compile time (comptime) from the element type T:

  • BF16 (v_mfma_f32_32x32x16_bf16): MMA_K=16, 8 BF16 elements per lane for each base MFMA. The K side uses the two-XOR worker-index swizzle with ds_read_b128. The V side uses the identity swizzle with the ds_read_tr16_b64_warp transpose-load.
  • FP8 e4m3 (v_mfma_scale_f32_32x32x64_f8f6f4): MMA_K=64, 32 FP8 elements per lane for each base MFMA. The K loader reuses the BF16 byte-level two-XOR swizzle. Because that swizzle is defined on byte positions and the sub-block is 64 B wide for both element sizes, the same swizzle serves both flavors. The FP8 V loader uses ds_read_tr8_b64 paired-lane transpose-loads, mirroring TiledMmaLoader.load_v_fp8_strip.

mma_QK and mma_PV are dtype-generic across both flavors: gpu_mma dispatches on the SIMD operand sizes.
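
The per-lane counts follow from spreading the 32 × MMA_K A operand over a 64-lane wavefront. For BF16 that is 32·16/64 = 8 elements, or 16 bytes, which matches one 128-bit ds_read_b128. For FP8 it is 32·64/64 = 32 elements, or 32 bytes. The Python sketch is a host-side model, not the Mojo implementation. It reproduces that arithmetic and picks an instruction from the per-lane operand shape, loosely in the spirit of gpu_mma's size-based dispatch. It also uses a hypothetical XOR swizzle to show why a swizzle defined on byte positions applies unchanged to both element sizes. The helpers select_mfma and swizzle_byte_offset are illustrative names, and swizzle_byte_offset is not the actual two-XOR worker-index swizzle.

```python
# Host-side model of the two MFMA flavors (illustrative, not the Mojo code).
# Assumptions: a 64-lane wavefront; the 32 x MMA_K A operand is split evenly
# across lanes; swizzle_byte_offset is a hypothetical stand-in for the real
# two-XOR worker-index swizzle.

WAVE_LANES = 64
MFMA_M = 32
ROW_BYTES = 64    # sub-block width in bytes, the same for both flavors
CHUNK_BYTES = 16  # assumed swizzle granularity: one 128-bit chunk

# name -> (element size in bytes, MMA_K, instruction)
FLAVORS = {
    "bf16": (2, 16, "v_mfma_f32_32x32x16_bf16"),
    "fp8_e4m3": (1, 64, "v_mfma_scale_f32_32x32x64_f8f6f4"),
}


def elems_per_lane(name):
    """Elements of the 32 x MMA_K operand held by each lane."""
    _, mma_k, _ = FLAVORS[name]
    return MFMA_M * mma_k // WAVE_LANES


def select_mfma(elem_bytes, width):
    """Pick an instruction from the per-lane operand shape (element size,
    SIMD width), loosely modeling dispatch on operand sizes."""
    for name, (size, _, instr) in FLAVORS.items():
        if size == elem_bytes and elems_per_lane(name) == width:
            return instr
    raise ValueError(f"no MFMA for {width} x {elem_bytes}-byte elements")


def swizzle_byte_offset(row, byte_off):
    """Hypothetical XOR swizzle on byte positions: permutes the four 16-byte
    chunks of a 64-byte row by XOR-ing the chunk index with the row index."""
    chunk, within = divmod(byte_off, CHUNK_BYTES)
    chunk ^= row % (ROW_BYTES // CHUNK_BYTES)
    return chunk * CHUNK_BYTES + within


def swizzle_element(name, row, elem):
    """Convert an element index to a byte offset first, so the same
    byte-level swizzle serves both element sizes."""
    size, _, _ = FLAVORS[name]
    return swizzle_byte_offset(row, elem * size)


for name, (size, _, _) in FLAVORS.items():
    n = elems_per_lane(name)
    print(name, n, "elements/lane =", n * size, "bytes/lane ->", select_mfma(size, n))
```

BF16 element e and FP8 element 2e both sit at byte offset 2e, so they land in the same swizzled location. Only the 64-byte row width has to agree between the two flavors.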

comptime values

ACC_ROW_OFFSETS_32x32

comptime ACC_ROW_OFFSETS_32x32 = SIMD[DType.int32, 16](0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27)
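
The page does not describe this constant. The sketch assumes it lists the output-row indices of the 16 f32 accumulator registers that each lane holds after a 32x32 MFMA. This is the usual CDNA layout, in which lane l owns column l % 32 and lanes 32-63 sit 4 rows below lanes 0-31. Under that assumption, the Python sketch rebuilds the offsets and checks that 64 lanes × 16 registers cover the 32×32 tile exactly once. The helpers row_offset, acc_row, and acc_col are hypothetical.

```python
# Illustrative model of a 32x32 MFMA accumulator layout (assumed, not taken
# from the Mojo source): 64 lanes x 16 f32 registers cover a 32x32 tile.

ACC_ROW_OFFSETS_32x32 = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27]


def row_offset(i):
    """Register i lives in block i // 4 (blocks are 8 rows apart), row i % 4."""
    return 8 * (i // 4) + i % 4


def acc_row(lane, i):
    """Hypothetical helper: output row of accumulator register i in this lane.
    The upper half-wave (lanes 32-63) is shifted down by 4 rows."""
    return ACC_ROW_OFFSETS_32x32[i] + 4 * (lane // 32)


def acc_col(lane):
    """Hypothetical helper: output column owned by this lane."""
    return lane % 32


covered = {(acc_row(lane, i), acc_col(lane)) for lane in range(64) for i in range(16)}
print(len(covered))  # 1024
```

Because the offsets are always 0 to 3 modulo 8, the +4 shift for the upper half-wave fills the remaining rows without overlap.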

Structs

  • MhaConfigV2: Shape configuration for MhaPrefillV2.
  • MhaMmaOp: Namespace-style struct holding the shape constants, register-tile layouts, and SMEM→register loaders for MhaPrefillV2. All call sites go through static methods on this struct.
  • MlaConfigV2: Shape configuration for MlaPrefillV2Core. Companion to MhaConfigV2.
  • MlaMmaOp: MLA MMA op for MlaPrefillV2 / MlaPrefillV2Core.