Mojo struct
MLAConfig
struct MLAConfig[qkv_dtype: DType, *, rope_gmem_dtype: DType, rope_mma_dtype: DType, scale_dtype: DType = DType.invalid]
Fields
- fa4_config (FA4Config[qkv_dtype, rope_dtype=rope_mma_dtype, scale_dtype=scale_dtype])
- MMA_M (Int)
- BM (Int)
- BN (Int)
- BK0 (Int)
- BK1 (Int)
- qk_depth (Int)
- rope_depth (Int)
- nope_depth (Int)
- cache_depth (Int)
- padded_qk_depth (Int)
- group (Int)
- num_q_heads (Int)
- num_kv_heads (Int)
- TMEM_S1 (Int)
- TMEM_O0 (Int)
- TMEM_O1 (Int)
- TMEM_P0 (Int)
- TMEM_P1 (Int)
- TMEM_C0 (Int)
- TMEM_C1 (Int)
- tmem_used (Int)
- num_kv_stages (Int)
- num_qk_stages (Int)
- num_pv_stages (Int)
- smem_used (Int)
- split_m (Bool)
- qkv_swizzle_mode (TensorMapSwizzle)
- rope_mma_swizzle_mode (TensorMapSwizzle)
- rope_gmem_swizzle_mode (TensorMapSwizzle)
- output_swizzle_mode (TensorMapSwizzle)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
mbar_size
comptime mbar_size = size_of[DType.int64]()
num_correction_cols
comptime num_correction_cols = 1
num_threads
comptime num_threads = 512
qkv_dtype_size
comptime qkv_dtype_size = size_of[qkv_dtype]()
rope_gmem_dtype_size
comptime rope_gmem_dtype_size = size_of[rope_gmem_dtype]()
rope_mma_dtype_size
comptime rope_mma_dtype_size = size_of[rope_mma_dtype]()
sm100_smem_carveout
comptime sm100_smem_carveout = (B200 - 1024)
sm100_tmem_cols
comptime sm100_tmem_cols = 512
TMEM_S0
comptime TMEM_S0 = 0
Methods
__init__
__init__(*, num_q_heads: Int, group: Int, depth: Int, page_size: Int) -> Self
num_qo
num_rope_buffers
supported
correction_smem_elements
num_active_warps_per_group
num_active_threads_per_group
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!