Mojo struct
MLAConfig
struct MLAConfig[qkv_dtype: DType, *, rope_gmem_dtype: DType, rope_mma_dtype: DType, scale_dtype: DType = DType.invalid]
Fields
- fa4_config (FA4Config[qkv_dtype, rope_dtype=rope_mma_dtype, scale_dtype=scale_dtype])
- MMA_M (Int)
- BM (Int)
- BN (Int)
- BK0 (Int)
- BK1 (Int)
- qk_depth (Int)
- rope_depth (Int)
- nope_depth (Int)
- cache_depth (Int)
- padded_qk_depth (Int)
- group (Int)
- num_q_heads (Int)
- num_kv_heads (Int)
- TMEM_S1 (Int)
- TMEM_O0 (Int)
- TMEM_O1 (Int)
- TMEM_P0 (Int)
- TMEM_P1 (Int)
- TMEM_C0 (Int)
- TMEM_C1 (Int)
- tmem_used (Int)
- num_kv_stages (Int)
- num_qk_stages (Int)
- num_pv_stages (Int)
- smem_used (Int)
- qkv_swizzle_mode (TensorMapSwizzle)
- rope_mma_swizzle_mode (TensorMapSwizzle)
- rope_gmem_swizzle_mode (TensorMapSwizzle)
- output_swizzle_mode (TensorMapSwizzle)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
mbar_size
comptime mbar_size = size_of[DType.int64]()
num_correction_cols
comptime num_correction_cols = 1
num_threads
comptime num_threads = 512
qkv_dtype_size
comptime qkv_dtype_size = size_of[qkv_dtype]()
rope_gmem_dtype_size
comptime rope_gmem_dtype_size = size_of[rope_gmem_dtype]()
rope_mma_dtype_size
comptime rope_mma_dtype_size = size_of[rope_mma_dtype]()
sm100_smem_carveout
comptime sm100_smem_carveout = (B200 - 1024)
sm100_tmem_cols
comptime sm100_tmem_cols = 512
TMEM_S0
comptime TMEM_S0 = 0
Methods
__init__
__init__(*, num_q_heads: Int, group: Int, depth: Int, page_size: Int) -> Self
num_qo
num_rope_buffers
supported
correction_smem_elements
num_active_warps_per_group
num_active_threads_per_group