Mojo struct
MLAConfig
struct MLAConfig[qkv_dtype: DType, *, rope_gmem_dtype: DType, rope_mma_dtype: DType, scale_dtype: DType = DType.invalid]
Fields
- fa4_config (FA4Config[qkv_dtype, rope_dtype=rope_mma_dtype, scale_dtype=scale_dtype])
- MMA_M (Int)
- BM (Int)
- BN (Int)
- BK0 (Int)
- BK1 (Int)
- qk_depth (Int)
- rope_depth (Int)
- nope_depth (Int)
- cache_depth (Int)
- padded_qk_depth (Int)
- group (Int)
- num_q_heads (Int)
- num_kv_heads (Int)
- TMEM_S1 (Int)
- TMEM_O0 (Int)
- TMEM_O1 (Int)
- TMEM_P0 (Int)
- TMEM_P1 (Int)
- tmem_used (Int)
- num_kv_stages (Int)
- num_qk_stages (Int)
- num_pv_stages (Int)
- smem_used (Int)
- qkv_swizzle_mode (TensorMapSwizzle)
- rope_mma_swizzle_mode (TensorMapSwizzle)
- rope_gmem_swizzle_mode (TensorMapSwizzle)
- output_swizzle_mode (TensorMapSwizzle)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
mbar_size
comptime mbar_size = size_of[DType.int64]()
num_correction_cols
comptime num_correction_cols = 1
num_threads
comptime num_threads = Int(512)
qkv_dtype_size
comptime qkv_dtype_size = size_of[qkv_dtype]()
rope_gmem_dtype_size
comptime rope_gmem_dtype_size = size_of[rope_gmem_dtype]()
rope_mma_dtype_size
comptime rope_mma_dtype_size = size_of[rope_mma_dtype]()
sm100_smem_carveout
comptime sm100_smem_carveout = (GPUInfo.from_family(AcceleratorArchitectureFamily(Int(32), Int(2048), Int(233472), Int(65536), Int(1024)), StringSlice("B200"), Vendor(Int8(2)), StringSlice("cuda"), StringSlice("blackwell"), SIMD(10), StringSlice("sm_100a"), Int(148)) - Int(1024))
sm100_tmem_cols
comptime sm100_tmem_cols = 512
TMEM_S0
comptime TMEM_S0 = Int(0)
Methods
__init__
def __init__(*, num_q_heads: Int, group: Int, depth: Int, page_size: Int, num_qo: Int = Int(2)) -> Self
num_qo
q_tile_rows
def q_tile_rows(self) -> Int
Rows per Q TMA tile and per MMA half: BM // num_qo.
This is 128 in both modes: in 2Q it is one of the two halves of the BM=256 tile, and in 1Q it is the single full BM=128 tile. The Q TMA box, the per-token q_scale TMA box, and the ragged output store all fold to this value, which is why their op types match across the 1Q and 2Q configs.
Returns:
Int
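The arithmetic above can be checked with a small Python model. This is only an illustrative sketch, not the Mojo implementation: the function is a hypothetical stand-in that takes BM and num_qo as plain integers, using the BM values quoted in the description (256 for 2Q, 128 for 1Q).

```python
def q_tile_rows(bm: int, num_qo: int) -> int:
    """Model of the documented formula: rows per Q tile = BM // num_qo.

    `bm` and `num_qo` are plain integers standing in for the config
    fields; this is not the Mojo method itself.
    """
    if num_qo <= 0 or bm % num_qo != 0:
        raise ValueError("BM must be a positive multiple of num_qo")
    return bm // num_qo


# 2Q mode: BM=256 is split into two halves of 128 rows each.
rows_2q = q_tile_rows(bm=256, num_qo=2)
# 1Q mode: the single BM=128 tile is used whole.
rows_1q = q_tile_rows(bm=128, num_qo=1)

# Both modes yield the same tile height, which is the property the
# description relies on when it says the TMA op types match.
print(rows_2q, rows_1q)
```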
with_num_qo
def with_num_qo(self, num_qo: Int) -> Self
Reconstruct this config with a different num_qo (single-CTA).
Mirrors FA4Config.with_num_qo, but simpler: MLA pins num_qk_stages == 1 (is_mla), so there is no staging knob to match between the 1Q and 2Q variants.
switch_1q_config
def switch_1q_config(self) -> Self
The 1Q variant used by the in-kernel per-sequence 1Q/2Q switch.
Identical to with_num_qo(1). See with_num_qo for why MLA, unlike FA4Config.switch_1q_config, has no staging-pinning concern.
can_switch_to_1q
def can_switch_to_1q(self) -> Bool
Whether a 2Q-launched kernel may dispatch to the 1Q body at runtime.
True only when this is a 2Q config and the 1Q variant is valid. The TMA-op types fold between the two configs by construction: the Q TMA / ragged-store row count BM // num_qo is 128 in both modes, and the K_nope/K_rope/V TMA shapes are independent of BM (BN's formula does not reference num_qo).
Returns:
Bool
launch_smem_used
def launch_smem_used(self) -> Int
Dynamic smem to reserve when launching this config's kernel.
When the launched kernel may dispatch to the 1Q body at runtime (can_switch_to_1q()), it constructs the 1Q SM100AttentionSMem over the same dynamic smem region, so the launch must reserve the maximum of the two footprints. Otherwise this is just smem_used.
Returns:
Int
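The reservation rule described above can be modeled in a few lines of Python. This is a sketch under stated assumptions rather than the Mojo implementation: the function takes the two footprints and the switch flag as plain arguments, and the byte counts in the usage lines are made-up illustrative values.

```python
def launch_smem_used(smem_used: int, smem_used_1q: int,
                     can_switch_to_1q: bool) -> int:
    """Model of the documented rule for dynamic shared memory.

    smem_used        -- footprint of the launched config (bytes)
    smem_used_1q     -- footprint of its 1Q variant (bytes)
    can_switch_to_1q -- whether the kernel may run the 1Q body at runtime

    All three are hypothetical plain-value stand-ins for the config's
    own fields and methods.
    """
    if can_switch_to_1q:
        # Both layouts are built over the same region, so the launch
        # has to cover whichever of the two is larger.
        return max(smem_used, smem_used_1q)
    return smem_used


# Illustrative numbers only (not taken from any real config).
print(launch_smem_used(200_000, 210_000, can_switch_to_1q=True))
print(launch_smem_used(200_000, 210_000, can_switch_to_1q=False))
```

Taking the maximum, rather than assuming the 2Q footprint is always larger, keeps the launch safe even if the two layouts differ in which one needs more space.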
prefer_1q
def prefer_1q(self, max_prompt_len: UInt32, num_partitions: UInt32, batch_size: UInt32, sm_count: Int) -> Bool
Runtime 1Q-vs-2Q grid heuristic for a 2Q config (mirrors the MHA heuristic in dispatch.mojo). Prefer 1Q when either (a) max_prompt_len fits in a single 1Q tile (q_tile_rows()), so 2Q's BM=256 would waste at least 50% of the Q rows, or (b) the unclamped 2Q grid fills at most half the SMs, so halving BM doubles the grid without oversubscribing.
Returns:
Bool
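The two conditions can be sketched in Python as follows. This is a hedged model, not the actual heuristic: the description does not give the formula for the unclamped 2Q grid, so the product `ceil(max_prompt_len / 256) * num_partitions * batch_size` used below is an assumption made purely for illustration, as are the default tile sizes passed as keyword arguments.

```python
def ceil_div(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def prefer_1q(max_prompt_len: int, num_partitions: int, batch_size: int,
              sm_count: int, *, q_tile_rows: int = 128,
              bm_2q: int = 256) -> bool:
    """Model of the documented 1Q-vs-2Q preference for a 2Q config."""
    # (a) Every prompt fits in one 1Q tile, so a 256-row 2Q tile would
    #     leave at least half of its Q rows unused.
    if max_prompt_len <= q_tile_rows:
        return True
    # (b) ASSUMED grid formula: Q tiles per sequence times partitions
    #     times batch. The real computation may include other factors.
    grid_2q = ceil_div(max_prompt_len, bm_2q) * num_partitions * batch_size
    # Halving BM doubles the grid; that is fine while the doubled grid
    # still fits within the available SMs.
    return 2 * grid_2q <= sm_count


# sm_count=148 is used here only as an example value.
print(prefer_1q(100, 1, 64, 148))    # condition (a) holds
print(prefer_1q(1024, 1, 8, 148))    # grid of 32 fills under half the SMs
print(prefer_1q(4096, 1, 16, 148))   # grid of 256 already exceeds the SMs
```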
num_rope_buffers
supported
correction_smem_elements
num_active_warps_per_group
num_active_threads_per_group