IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MhaConfigV2

struct MhaConfigV2

Shape configuration for MhaPrefillV2.

Single source of truth for the shape parameters that drive register-tile layouts, SMEM sub-block geometry, grid dimensions, and IGLP scheduling. Lives in mha_mma_op so both MhaMmaOp and MhaPrefillV2 can take it as a parameter without a circular import.

Fields

  • q_block_size (Int): Q rows per warp.
  • kv_block (Int): K/V rows per tile (64 at d=128).
  • depth (Int): Head depth.
  • num_heads (Int): Number of Q heads.
  • num_kv_heads (Int): Number of K/V heads. Must be 1 (full GQA) or equal to num_heads (MHA); other ratios need a stride-aware DMA loader (TODO).
  • num_warps (Int): Warps per block.
  • rescale_threshold (Float32): Lazy-rescale threshold, in log2 units of the running max. When the running max grows by more than this amount, o_reg / norm_vec are rescaled by exp2(max_prev - max_new). Otherwise the rescale is deferred, and the residual contribution from att_block is bounded by exp2(-rescale_threshold) of the new max.
  • dtype (DType): Element dtype of the Q / K / V input tiles: DType.bfloat16 for the production BF16 prefill path, DType.float8_e4m3fn for the FP8 prefill path. Comptime-dispatched throughout MhaMmaOp (MFMA shape, swizzle, SMEM sub-block dims) and MhaPrefillV2 (register tiles, SMEM slots, cooperative loaders).
  • output_dtype (DType): Element dtype of the output tile o. FP32 by default; BF16 for production inference, where the dispatcher holds a BF16 output buffer. The cast from the FP32 accumulator happens per lane inside _store_o_to_gmem.
  • fp8_mma_k_128 (Bool): When True and dtype is an FP8 type, selects the v_mfma_scale_f32_16x16x128_f8f6f4 MFMA shape (MMA_M=MMA_N=16, MMA_K=128) instead of the default v_mfma_scale_f32_32x32x64_f8f6f4 (MMA_M=MMA_N=32, MMA_K=64). The 16×16×128 shape issues every 16 cycles vs. 32 cycles for 32×32×64, mirroring the FP8 ping-pong / 4-wave matmul choice (see amd_ping_pong_matmul.mojo:716-737). BF16 is unaffected: it always uses 32×32×16. Defaults to False (the current 32×32×64 path).
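
The rescale_threshold behavior can be illustrated with a scalar model of online softmax using lazy rescaling. This is a minimal Python sketch, not the Mojo kernel. It assumes the rescale is triggered when the running max grows by more than rescale_threshold, and that scores are already in log2 units (pre-multiplied by log2(e)). It also uses a single query with scalar values (depth 1). The names lazy_rescale_attention and reference_attention are hypothetical, and acc / norm stand in for o_reg / norm_vec.

```python
def exp2(x: float) -> float:
    # 2**x; Python returns 0.0 for x == -inf, matching exp2 semantics.
    return 2.0 ** x


def lazy_rescale_attention(score_blocks, value_blocks, rescale_threshold=8.0):
    """Online softmax over K/V blocks with deferred (lazy) rescaling.

    Returns (output, number_of_rescales_after_the_first_block).
    """
    ref_max = float("-inf")  # reference max used by exp2; may lag the true max
    norm = 0.0               # stands in for norm_vec
    acc = 0.0                # stands in for o_reg
    rescales = 0
    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(ref_max, max(scores))
        # Assumed trigger: rescale only if the max grew by more than the threshold.
        if new_max - ref_max > rescale_threshold:
            if ref_max != float("-inf"):
                rescales += 1
            factor = exp2(ref_max - new_max)  # 0.0 on the first block
            acc *= factor
            norm *= factor
            ref_max = new_max
        # Otherwise keep the stale ref_max: every p below is <= exp2(rescale_threshold),
        # and the final acc / norm division cancels the stale reference exactly.
        for s, v in zip(scores, values):
            p = exp2(s - ref_max)
            norm += p
            acc += p * v
    return acc / norm, rescales


def reference_attention(score_blocks, value_blocks):
    # Eager softmax over all scores at once, for comparison.
    scores = [s for blk in score_blocks for s in blk]
    values = [v for blk in value_blocks for v in blk]
    m = max(scores)
    weights = [exp2(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)
```

With a larger threshold, fewer blocks pay for a rescale of the accumulator. The normalized result matches the eager softmax either way. The tradeoff is a bounded growth of the unnormalized terms between rescales.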

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable

Methods

__init__

def __init__(out self, *, q_block_size: Int, kv_block: Int, depth: Int, num_heads: Int, num_kv_heads: Int, num_warps: Int = 8, rescale_threshold: Float32 = 8, dtype: DType = DType.bfloat16, output_dtype: DType = DType.float32, fp8_mma_k_128: Bool = False)
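
The constructor's defaults and the documented MFMA shape selection can be modeled on the host side. The following is a minimal Python sketch, not the Mojo API. MhaConfigV2Model, mma_shape, and the dtype strings are hypothetical stand-ins for MhaConfigV2 and DType. The num_kv_heads check only encodes the documented restriction; whether the real constructor validates it is not stated here.

```python
from dataclasses import dataclass

# Hypothetical string stand-ins for Mojo DType values.
BF16 = "bfloat16"
FP8_E4M3FN = "float8_e4m3fn"
FP32 = "float32"


@dataclass(frozen=True)
class MhaConfigV2Model:
    # Field order and defaults mirror the documented __init__ signature.
    q_block_size: int
    kv_block: int
    depth: int
    num_heads: int
    num_kv_heads: int
    num_warps: int = 8
    rescale_threshold: float = 8.0
    dtype: str = BF16
    output_dtype: str = FP32
    fp8_mma_k_128: bool = False

    def __post_init__(self):
        # Documented restriction: 1 (full GQA) or num_heads (MHA).
        if self.num_kv_heads not in (1, self.num_heads):
            raise ValueError("num_kv_heads must be 1 or equal to num_heads")

    def is_fp8(self) -> bool:
        return self.dtype.startswith("float8")

    def mma_shape(self):
        """Return (MMA_M, MMA_N, MMA_K) following the field descriptions."""
        if not self.is_fp8():
            return (32, 32, 16)   # BF16 always uses 32x32x16
        if self.fp8_mma_k_128:
            return (16, 16, 128)  # v_mfma_scale_f32_16x16x128_f8f6f4
        return (32, 32, 64)       # default v_mfma_scale_f32_32x32x64_f8f6f4
```

For example, an MHA config with illustrative sizes would be built as MhaConfigV2Model(q_block_size=32, kv_block=64, depth=128, num_heads=32, num_kv_heads=32). Passing fp8_mma_k_128=True changes the shape only when dtype is FP8.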