
Mojo struct

MLAConfig

struct MLAConfig[qkv_dtype: DType, *, rope_gmem_dtype: DType, rope_mma_dtype: DType, scale_dtype: DType = DType.invalid]

Fields

  • fa4_config (FA4Config[qkv_dtype, rope_dtype=rope_mma_dtype, scale_dtype=scale_dtype])
  • MMA_M (Int)
  • BM (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • qk_depth (Int)
  • rope_depth (Int)
  • nope_depth (Int)
  • cache_depth (Int)
  • padded_qk_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_S1 (Int)
  • TMEM_O0 (Int)
  • TMEM_O1 (Int)
  • TMEM_P0 (Int)
  • TMEM_P1 (Int)
  • TMEM_C0 (Int)
  • TMEM_C1 (Int)
  • tmem_used (Int)
  • num_kv_stages (Int)
  • num_qk_stages (Int)
  • num_pv_stages (Int)
  • smem_used (Int)
  • qkv_swizzle_mode (TensorMapSwizzle)
  • rope_mma_swizzle_mode (TensorMapSwizzle)
  • rope_gmem_swizzle_mode (TensorMapSwizzle)
  • output_swizzle_mode (TensorMapSwizzle)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

mbar_size

comptime mbar_size = size_of[DType.int64]()

num_correction_cols

comptime num_correction_cols = 1

num_threads

comptime num_threads = 512

qkv_dtype_size

comptime qkv_dtype_size = size_of[qkv_dtype]()

rope_gmem_dtype_size

comptime rope_gmem_dtype_size = size_of[rope_gmem_dtype]()

rope_mma_dtype_size

comptime rope_mma_dtype_size = size_of[rope_mma_dtype]()

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512

TMEM_S0

comptime TMEM_S0 = 0

Methods

__init__

def __init__(*, num_q_heads: Int, group: Int, depth: Int, page_size: Int) -> Self

num_qo

def num_qo(self) -> Int

Returns:

Int

num_rope_buffers

def num_rope_buffers(self) -> Int

Returns:

Int

supported

def supported(self) -> Bool

Returns:

Bool

correction_smem_elements

def correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

def num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

def num_active_threads_per_group(self) -> Int

Returns:

Int