Mojo struct
MhaConfigV2
struct MhaConfigV2
Shape configuration for MhaPrefillV2.
Single source of truth for the shape parameters that drive
register-tile layouts, SMEM sub-block geometry, grid dimensions,
and IGLP scheduling. Lives in mha_mma_op so both MhaMmaOp
and MhaPrefillV2 can take it as a parameter without a circular
import.
Fields
- `q_block_size` (`Int`): Q rows per warp.
- `kv_block` (`Int`): K/V rows per tile (64 at d=128).
- `depth` (`Int`): Head depth.
- `num_heads` (`Int`): Number of Q heads.
- `num_kv_heads` (`Int`): Number of K/V heads. Must be `1` (full GQA) or equal to `num_heads` (MHA); other ratios need a stride-aware DMA loader (TODO).
- `num_warps` (`Int`): Warps per block.
- `rescale_threshold` (`Float32`): Lazy-rescale threshold, in log2 units of the running max. When the running max grows by more than this threshold, `o_reg` and `norm_vec` are rescaled by `exp2(max_prev - max_new)`; below it, the rescale is deferred, and the residual contribution from `att_block` is bounded by `exp2(-rescale_threshold)` of the new max.
- `dtype` (`DType`): Element dtype of the Q / K / V input tiles: `DType.bfloat16` for the production BF16 prefill path, `DType.float8_e4m3fn` for the FP8 prefill path. It is dispatched at compile time (comptime) throughout `MhaMmaOp` (MFMA shape, swizzle, SMEM sub-block dims) and `MhaPrefillV2` (register tiles, SMEM slots, cooperative loaders).
- `output_dtype` (`DType`): Element dtype of the output tile `o`. FP32 by default; BF16 for production inference, where the dispatcher holds a BF16 output buffer. The cast from the FP32 accumulator happens per lane inside `_store_o_to_gmem`.
- `fp8_mma_k_128` (`Bool`): When `True` and `dtype` is an FP8 type, selects the `v_mfma_scale_f32_16x16x128_f8f6f4` MFMA shape (MMA_M=MMA_N=16, MMA_K=128) instead of the default `v_mfma_scale_f32_32x32x64_f8f6f4` (MMA_M=MMA_N=32, MMA_K=64). The 16×16×128 shape issues every 16 cycles versus 32 cycles for 32×32×64, mirroring the FP8 ping-pong / 4-wave matmul choice (see `amd_ping_pong_matmul.mojo:716-737`). BF16 is unaffected: it always uses 32×32×16. Defaults to `False` (the current 32×32×64 path).
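The lazy-rescale behavior controlled by `rescale_threshold` can be illustrated with a minimal Python sketch for a single query row. It assumes the scores are already scaled into log2 units and uses plain lists instead of register tiles; `ref_max`, `norm`, and `out` loosely stand in for the running max, `norm_vec`, and `o_reg`. The function names are hypothetical and this is not the `MhaPrefillV2` implementation. The point it shows: deferring a rescale does not change the final result, because the accumulators and every exponential share one reference max, which cancels in the final division.

```python
import math


def lazy_rescale_attention_row(score_blocks, value_blocks, rescale_threshold=8.0):
    """Hypothetical sketch: attention output for one query row, K/V in blocks.

    score_blocks: list of non-empty lists of scores, already in log2 units.
    value_blocks: matching list of lists of value vectors (one per score).
    """
    ref_max = -math.inf  # reference max shared by all exponentials so far
    norm = 0.0           # running softmax denominator (norm_vec analogue)
    out = None           # running unnormalized output (o_reg analogue)
    for scores, values in zip(score_blocks, value_blocks):
        new_max = max(ref_max, max(scores))
        if new_max - ref_max > rescale_threshold:
            # Eager path: move the reference to the new max and rescale
            # the accumulators by exp2(max_prev - max_new).
            factor = 2.0 ** (ref_max - new_max)  # 0.0 on the first block
            norm *= factor
            if out is not None:
                out = [o * factor for o in out]
            ref_max = new_max
        # Deferred path (or just after a rescale): new_max - ref_max is now at
        # most rescale_threshold, so for a non-negative threshold each weight
        # below is at most 2**rescale_threshold.
        for s, v in zip(scores, values):
            w = 2.0 ** (s - ref_max)
            norm += w
            if out is None:
                out = [0.0] * len(v)
            out = [o + w * x for o, x in zip(out, v)]
    return [o / norm for o in out]


def reference_attention_row(scores, values):
    """Plain (non-blocked) base-2 softmax attention for comparison."""
    m = max(scores)
    weights = [2.0 ** (s - m) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]
```

For any non-negative threshold, the two functions agree up to floating-point rounding. A threshold of 0 rescales whenever the max grows (the eager online-softmax behavior); larger values skip more rescales at the cost of unnormalized weights as large as `2**rescale_threshold`.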
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable
Methods
__init__
def __init__(out self, *, q_block_size: Int, kv_block: Int, depth: Int, num_heads: Int, num_kv_heads: Int, num_warps: Int = 8, rescale_threshold: Float32 = 8, dtype: DType = DType.bfloat16, output_dtype: DType = DType.float32, fp8_mma_k_128: Bool = False)
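To make the keyword arguments and defaults concrete, here is a hypothetical Python mirror of this constructor; it is not the Mojo API. `MhaConfigV2Sketch`, `mma_shape`, and `kv_head_for` are illustrative names, and dtypes are plain strings rather than `DType` values. The `num_kv_heads` check encodes the constraint documented for that field (whether `MhaConfigV2` itself validates it at construction is not stated on this page), `mma_shape` restates the selection rule described under `fp8_mma_k_128`, and the Q-head to K/V-head mapping is the usual grouped-query convention, assumed rather than taken from these docs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MhaConfigV2Sketch:
    # Hypothetical Python mirror of MhaConfigV2.__init__ (same names and
    # defaults); dtypes are strings here instead of Mojo DType values.
    q_block_size: int
    kv_block: int
    depth: int
    num_heads: int
    num_kv_heads: int
    num_warps: int = 8
    rescale_threshold: float = 8.0
    dtype: str = "bfloat16"
    output_dtype: str = "float32"
    fp8_mma_k_128: bool = False

    def __post_init__(self):
        # Documented constraint: num_kv_heads is 1 (full GQA) or equal to
        # num_heads (MHA); other ratios need a stride-aware DMA loader.
        if self.num_kv_heads not in (1, self.num_heads):
            raise ValueError(
                f"unsupported num_kv_heads={self.num_kv_heads} "
                f"for num_heads={self.num_heads}"
            )

    @property
    def is_fp8(self) -> bool:
        return self.dtype.startswith("float8")

    def mma_shape(self) -> tuple:
        """(MMA_M, MMA_N, MMA_K) per the fp8_mma_k_128 description."""
        if not self.is_fp8:
            return (32, 32, 16)   # BF16 always uses 32x32x16
        if self.fp8_mma_k_128:
            return (16, 16, 128)  # v_mfma_scale_f32_16x16x128_f8f6f4
        return (32, 32, 64)       # v_mfma_scale_f32_32x32x64_f8f6f4

    def kv_head_for(self, q_head: int) -> int:
        # Assumed grouped-query mapping: consecutive groups of Q heads
        # share one K/V head (all share head 0 when num_kv_heads == 1).
        group = self.num_heads // self.num_kv_heads
        return q_head // group


# Illustrative values only; kv_block=64 at depth=128 follows the field docs.
cfg = MhaConfigV2Sketch(q_block_size=32, kv_block=64, depth=128,
                        num_heads=32, num_kv_heads=1,
                        dtype="float8_e4m3fn", fp8_mma_k_128=True)
print(cfg.mma_shape(), cfg.kv_head_for(17))  # (16, 16, 128) 0
```

Keeping every shape parameter in one immutable value, as the struct description says, is what lets both the MMA op and the prefill kernel derive their layouts from the same source.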