
Mojo struct

MLAConfig

struct MLAConfig[qkv_dtype: DType, *, rope_gmem_dtype: DType, rope_mma_dtype: DType, scale_dtype: DType = DType.invalid]

Fields

  • fa4_config (FA4Config[qkv_dtype, rope_dtype=rope_mma_dtype, scale_dtype=scale_dtype])
  • MMA_M (Int)
  • BM (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • qk_depth (Int)
  • rope_depth (Int)
  • nope_depth (Int)
  • cache_depth (Int)
  • padded_qk_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_S1 (Int)
  • TMEM_O0 (Int)
  • TMEM_O1 (Int)
  • TMEM_P0 (Int)
  • TMEM_P1 (Int)
  • TMEM_C0 (Int)
  • TMEM_C1 (Int)
  • tmem_used (Int)
  • num_kv_stages (Int)
  • num_qk_stages (Int)
  • num_pv_stages (Int)
  • smem_used (Int)
  • qkv_swizzle_mode (TensorMapSwizzle)
  • rope_mma_swizzle_mode (TensorMapSwizzle)
  • rope_gmem_swizzle_mode (TensorMapSwizzle)
  • output_swizzle_mode (TensorMapSwizzle)
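The fields above carry no descriptions. The Python model below is a hypothetical sketch of how the head-count and depth fields plausibly relate, assuming that `group` is the number of query heads sharing one KV head (so `num_kv_heads = num_q_heads // group`) and that `qk_depth` is split into a no-RoPE slice and a RoPE slice (`qk_depth = nope_depth + rope_depth`). Neither relationship is confirmed by this page, and `derive_head_fields` and `HeadFields` are illustrative names, not part of the Mojo API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class HeadFields:
    """Hypothetical mirror of a few MLAConfig fields (not the Mojo struct)."""
    num_q_heads: int
    num_kv_heads: int
    group: int
    qk_depth: int
    rope_depth: int
    nope_depth: int


def derive_head_fields(num_q_heads: int, group: int, qk_depth: int, rope_depth: int) -> HeadFields:
    # Assumption: every KV head is shared by exactly `group` query heads.
    if group <= 0 or num_q_heads % group != 0:
        raise ValueError("num_q_heads must be a positive multiple of group")
    # Assumption: the Q/K depth is a no-RoPE slice followed by a RoPE slice.
    if not 0 < rope_depth < qk_depth:
        raise ValueError("rope_depth must lie strictly between 0 and qk_depth")
    return HeadFields(
        num_q_heads=num_q_heads,
        num_kv_heads=num_q_heads // group,
        group=group,
        qk_depth=qk_depth,
        rope_depth=rope_depth,
        nope_depth=qk_depth - rope_depth,
    )
```

For instance, derive_head_fields(16, 4, 192, 64) yields num_kv_heads = 4 and nope_depth = 128. These numbers are arbitrary illustrations, not defaults of MLAConfig.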

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

mbar_size

comptime mbar_size = size_of[DType.int64]()

num_correction_cols

comptime num_correction_cols = 1

num_threads

comptime num_threads = 512

qkv_dtype_size

comptime qkv_dtype_size = size_of[qkv_dtype]()

rope_gmem_dtype_size

comptime rope_gmem_dtype_size = size_of[rope_gmem_dtype]()

rope_mma_dtype_size

comptime rope_mma_dtype_size = size_of[rope_mma_dtype]()

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512

TMEM_S0

comptime TMEM_S0 = 0
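The comptime members above are mostly byte sizes and hardware limits. The Python sketch below reproduces that arithmetic: mbar_size is the 8-byte width of an int64 (one 64-bit mbarrier object), and the 512 threads in num_threads form 16 warps of 32. The `fits_budget` helper is a hypothetical resource check, not the documented logic of supported(); it assumes that tmem_used is measured in TMEM columns against sm100_tmem_cols, and that the `B200` term in sm100_smem_carveout stands for a per-SM shared-memory capacity, which is passed in as a parameter rather than guessed. The names `tile_bytes`, `fits_budget`, and `DTYPE_BYTES` are illustrative only.

```python
# Byte widths that size_of[...]() reports for a few common dtypes.
DTYPE_BYTES = {
    "int64": 8,
    "float32": 4,
    "bfloat16": 2,
    "float16": 2,
    "float8_e4m3fn": 1,
}

MBAR_SIZE = DTYPE_BYTES["int64"]  # mbar_size: one 64-bit mbarrier object
NUM_THREADS = 512                 # num_threads
WARP_SIZE = 32                    # threads per warp on NVIDIA GPUs
SM100_TMEM_COLS = 512             # sm100_tmem_cols


def tile_bytes(rows: int, cols: int, dtype: str) -> int:
    """Bytes for a dense rows x cols tile (ignores any padding or alignment slack)."""
    return rows * cols * DTYPE_BYTES[dtype]


def fits_budget(tmem_used: int, smem_used: int, smem_per_sm_bytes: int) -> bool:
    """Hypothetical resource check; not the real supported() implementation."""
    # Assumption: the carveout is a per-SM shared-memory capacity minus 1024 bytes,
    # mirroring the documented expression (B200 - 1024).
    smem_carveout = smem_per_sm_bytes - 1024
    return tmem_used <= SM100_TMEM_COLS and smem_used <= smem_carveout
```

For example, a 128 × 192 bfloat16 tile takes tile_bytes(128, 192, "bfloat16") = 49152 bytes, which a config's smem_used would have to accommodate alongside its other buffers and barriers.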

Methods

__init__

__init__(*, num_q_heads: Int, group: Int, depth: Int, page_size: Int) -> Self

num_qo

num_qo(self) -> Int

Returns:

Int

num_rope_buffers

num_rope_buffers(self) -> Int

Returns:

Int

supported

supported(self) -> Bool

Returns:

Bool

correction_smem_elements

correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

num_active_threads_per_group(self) -> Int

Returns:

Int
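num_active_warps_per_group and num_active_threads_per_group are documented only by their return types. A plausible reading, sketched below in Python, is that a "group" is a 4-warp (128-thread) warpgroup and that the thread count is the active warp count times 32. Both points are assumptions, as is treating num_threads as the thread count of one thread block; `active_threads_per_group` is a hypothetical name, not the Mojo method.

```python
WARP_SIZE = 32        # threads per warp on NVIDIA GPUs
WARPGROUP_WARPS = 4   # assumption: a "group" is a 4-warp warpgroup
NUM_THREADS = 512     # num_threads from the comptime members


def active_threads_per_group(active_warps_per_group: int) -> int:
    """Hypothetical model of num_active_threads_per_group, not the Mojo code."""
    if not 0 < active_warps_per_group <= WARPGROUP_WARPS:
        raise ValueError("active warps must be in 1..WARPGROUP_WARPS")
    # Assumption: every active warp contributes all 32 of its lanes.
    return active_warps_per_group * WARP_SIZE


# Assuming 512 threads per block and 128-thread groups, a block holds 4 groups.
NUM_GROUPS = NUM_THREADS // (WARPGROUP_WARPS * WARP_SIZE)
```

Under these assumptions, a group with all four warps active reports 128 threads, and a group with two active warps reports 64.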