IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MLAConfig

struct MLAConfig[qkv_dtype: DType, *, rope_gmem_dtype: DType, rope_mma_dtype: DType, scale_dtype: DType = DType.invalid]

Fields

  • fa4_config (FA4Config[qkv_dtype, rope_dtype=rope_mma_dtype, scale_dtype=scale_dtype])
  • MMA_M (Int)
  • BM (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • qk_depth (Int)
  • rope_depth (Int)
  • nope_depth (Int)
  • cache_depth (Int)
  • padded_qk_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_S1 (Int)
  • TMEM_O0 (Int)
  • TMEM_O1 (Int)
  • TMEM_P0 (Int)
  • TMEM_P1 (Int)
  • tmem_used (Int)
  • num_kv_stages (Int)
  • num_qk_stages (Int)
  • num_pv_stages (Int)
  • smem_used (Int)
  • qkv_swizzle_mode (TensorMapSwizzle)
  • rope_mma_swizzle_mode (TensorMapSwizzle)
  • rope_gmem_swizzle_mode (TensorMapSwizzle)
  • output_swizzle_mode (TensorMapSwizzle)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

mbar_size

comptime mbar_size = size_of[DType.int64]()

num_correction_cols

comptime num_correction_cols = 1

num_threads

comptime num_threads = Int(512)

qkv_dtype_size

comptime qkv_dtype_size = size_of[qkv_dtype]()

rope_gmem_dtype_size

comptime rope_gmem_dtype_size = size_of[rope_gmem_dtype]()

rope_mma_dtype_size

comptime rope_mma_dtype_size = size_of[rope_mma_dtype]()

sm100_smem_carveout

comptime sm100_smem_carveout = (GPUInfo.from_family(AcceleratorArchitectureFamily(Int(32), Int(2048), Int(233472), Int(65536), Int(1024)), StringSlice("B200"), Vendor(Int8(2)), StringSlice("cuda"), StringSlice("blackwell"), SIMD(10), StringSlice("sm_100a"), Int(148)) - Int(1024))
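The rendered expression inlines the full B200 GPU description. The Python arithmetic below is a sketch under one assumption: the carveout is B200's per-SM shared-memory size minus the trailing `Int(1024)`. Here that size is taken to be the third `AcceleratorArchitectureFamily` argument, 233472 bytes (228 KiB). Both variable names are hypothetical.

```python
# Hypothetical names; the numbers are copied from the rendered expression.
B200_SMEM_PER_SM = 233472  # assumed: shared memory per SM in bytes (228 KiB)
RESERVED_BYTES = 1024      # the Int(1024) subtracted at the end of the expression

sm100_smem_carveout = B200_SMEM_PER_SM - RESERVED_BYTES
print(sm100_smem_carveout)  # 232448
```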

sm100_tmem_cols

comptime sm100_tmem_cols = 512

TMEM_S0

comptime TMEM_S0 = Int(0)

Methods

__init__

def __init__(*, num_q_heads: Int, group: Int, depth: Int, page_size: Int, num_qo: Int = Int(2)) -> Self

num_qo

def num_qo(self) -> Int

Returns:

Int

q_tile_rows

def q_tile_rows(self) -> Int

Rows per Q TMA tile / per-half MMA: BM // num_qo.

This is 128 in both modes: one of the two halves of the BM=256 tile in 2Q, or the single full BM=128 tile in 1Q. The Q (and per-token q_scale) TMA boxes and the ragged output store all fold to this value, which is why their op types match across the 1Q and 2Q configs.
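The folding argument reduces to one integer division. Here it is as a small Python model, using the BM and num_qo values named in the description above. The free function is illustrative; it is not the Mojo method.

```python
def q_tile_rows(bm: int, num_qo: int) -> int:
    """Rows per Q TMA tile / per-half MMA: BM // num_qo."""
    return bm // num_qo


# 2Q: BM=256 split into two 128-row halves. 1Q: one full BM=128 tile.
print(q_tile_rows(256, 2), q_tile_rows(128, 1))  # 128 128
```

Because both modes yield 128, any TMA op parameterized only on this row count has the same type in the 1Q and 2Q configs.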

Returns:

Int

with_num_qo

def with_num_qo(self, num_qo: Int) -> Self

Reconstruct this config with a different num_qo (single-CTA).

Mirrors FA4Config.with_num_qo, but simpler: MLA pins num_qk_stages == 1 (is_mla), so there is no staging knob to match between the 1Q and 2Q variants.

switch_1q_config

def switch_1q_config(self) -> Self

The 1Q variant used by the in-kernel per-sequence 1Q/2Q switch.

Identical to with_num_qo(1). Unlike FA4Config.switch_1q_config, MLA has no staging-pinning concern; see with_num_qo for why.
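A rough Python model of how with_num_qo and switch_1q_config relate. It rebuilds the config from its keyword-only constructor arguments with a new num_qo, and num_qk_stages stays at 1 either way. `MLAConfigModel` is a hypothetical stand-in that models only a few fields. `BM = 128 * num_qo` is an assumption chosen to match the BM=256 (2Q) and BM=128 (1Q) values described for q_tile_rows.

```python
class MLAConfigModel:
    """Hypothetical Python stand-in for MLAConfig; models only a few fields."""

    def __init__(self, *, num_q_heads, group, depth, page_size, num_qo=2):
        # Constructor arguments, kept so the config can be rebuilt later.
        self.num_q_heads = num_q_heads
        self.group = group
        self.depth = depth
        self.page_size = page_size
        self.num_qo = num_qo
        # Derived field. Assumption: BM = 128 * num_qo (256 in 2Q, 128 in 1Q).
        self.BM = 128 * num_qo
        # MLA pins num_qk_stages to 1, so it never differs between 1Q and 2Q.
        self.num_qk_stages = 1

    def with_num_qo(self, num_qo):
        # Re-run the constructor with the same arguments but a new num_qo.
        return MLAConfigModel(
            num_q_heads=self.num_q_heads,
            group=self.group,
            depth=self.depth,
            page_size=self.page_size,
            num_qo=num_qo,
        )

    def switch_1q_config(self):
        # Identical to with_num_qo(1).
        return self.with_num_qo(1)


# Illustrative values only.
cfg_2q = MLAConfigModel(num_q_heads=16, group=16, depth=576, page_size=128)
cfg_1q = cfg_2q.switch_1q_config()
print(cfg_2q.BM, cfg_1q.BM)  # 256 128
```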

can_switch_to_1q

def can_switch_to_1q(self) -> Bool

Whether a 2Q-launched kernel may dispatch to the 1Q body at runtime.

True only when this is a 2Q config and the 1Q variant is valid. The TMA op types fold between the two configs by construction: the Q TMA and ragged-store tile height BM // num_qo is 128 in both modes, and the K_nope/K_rope/V TMA shapes are independent of BM (BN's formula does not reference num_qo).
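The predicate can be modeled in Python as follows. The boolean input `one_q_variant_supported` is hypothetical. It stands in for "the 1Q variant is valid", which this page does not spell out.

```python
def can_switch_to_1q(num_qo: int, one_q_variant_supported: bool) -> bool:
    """Model of the rule: only a 2Q config whose 1Q variant is valid may switch."""
    # one_q_variant_supported is a hypothetical input for "the 1Q variant is valid".
    return num_qo == 2 and one_q_variant_supported


# The Q tile height BM // num_qo folds to the same value in both modes.
print(256 // 2 == 128 // 1)  # True
```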

Returns:

Bool

launch_smem_used

def launch_smem_used(self) -> Int

Dynamic smem to reserve when launching this config's kernel.

When the launched kernel may dispatch to the 1Q body at runtime (can_switch_to_1q()), it constructs the 1Q SM100AttentionSMem over the same dynamic smem region, so the launch must reserve the max of both footprints. Otherwise this is just smem_used.
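The reservation rule can be sketched in Python as below. The 1Q and 2Q shared-memory layouts alias the same dynamic region, so a switchable launch reserves the larger of the two footprints. The parameter names are illustrative, and the byte counts in the usage lines are made-up inputs, not real MLA footprints.

```python
def launch_smem_used(smem_used: int, smem_used_1q: int, can_switch: bool) -> int:
    """Dynamic smem to reserve: max of both footprints if the kernel can switch."""
    # Both SM100AttentionSMem layouts are built over the same dynamic smem
    # region, so the reservation must cover whichever one is larger.
    return max(smem_used, smem_used_1q) if can_switch else smem_used


print(launch_smem_used(200_000, 150_000, True))   # 200000
print(launch_smem_used(100_000, 150_000, False))  # 100000
```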

Returns:

Int

prefer_1q

def prefer_1q(self, max_prompt_len: UInt32, num_partitions: UInt32, batch_size: UInt32, sm_count: Int) -> Bool

Runtime 1Q-vs-2Q grid heuristic for a 2Q config (mirrors the MHA heuristic in dispatch.mojo). Prefer 1Q when either:

  • (a) max_prompt_len fits in a single 1Q tile (q_tile_rows()), so 2Q's BM=256 would waste at least 50% of the Q rows, or
  • (b) the unclamped 2Q grid fills at most half the SMs, so halving BM doubles the grid without oversubscribing.
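The two conditions are sketched below in Python. The grid formula is an assumption: one CTA per BM-row Q tile, per partition, per batch entry. The real kernel's grid may include other dimensions, such as heads. The defaults `bm_2q=256` and `q_tile_rows=128` come from the values described for q_tile_rows.

```python
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)


def prefer_1q(max_prompt_len: int, num_partitions: int, batch_size: int,
              sm_count: int, bm_2q: int = 256, q_tile_rows: int = 128) -> bool:
    # (a) The longest prompt fits in one 1Q tile, so BM=256 wastes >= 50% of rows.
    if max_prompt_len <= q_tile_rows:
        return True
    # (b) Assumed grid shape: Q tiles x partitions x batch entries.
    grid_2q = ceildiv(max_prompt_len, bm_2q) * num_partitions * batch_size
    # Halving BM doubles the grid; allow it only if the 2Q grid fills <= half the SMs.
    return 2 * grid_2q <= sm_count


print(prefer_1q(100, 1, 1, 148))     # True: fits one 128-row tile
print(prefer_1q(1000, 1, 4, 148))    # True: 16 CTAs, at most half of 148 SMs
print(prefer_1q(100000, 1, 8, 148))  # False: 3128 CTAs already fill the GPU
```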

Returns:

Bool

num_rope_buffers

def num_rope_buffers(self) -> Int

Returns:

Int

supported

def supported(self) -> Bool

Returns:

Bool

correction_smem_elements

def correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

def num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

def num_active_threads_per_group(self) -> Int

Returns:

Int