
Mojo struct

MLAKVLayouts

struct MLAKVLayouts[k_nope_dtype: DType, k_rope_dtype: DType, kv_scale_dtype: DType, config: MLAConfig[config.qkv_dtype, rope_gmem_dtype=config.rope_gmem_dtype, rope_mma_dtype=config.rope_mma_dtype, scale_dtype=config.scale_dtype]]

Comptime layout and size metadata for MLA K/V tiles.

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

k_bytes

comptime k_bytes = (Int((mul (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()), (config // Int(8)), (Int(128) // (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())), size_of[k_nope_dtype](), 8)) + Int((mul (config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]()), (config // Int(8)), (Int(64) // (config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]())), size_of[k_rope_dtype](), 8)))
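The `k_bytes` expression is the sum of two terms of the same shape: one for the nope part (width 128, sized by `k_nope_dtype`) and one for the rope part (width 64, sized by `k_rope_dtype`). The sketch below mirrors that arithmetic factor by factor in Python rather than Mojo. The rendered expression shows `(config // Int(8))` without a field name, so `rows` here is a hypothetical stand-in for that operand; `tile_bytes`, `k_bytes_sketch`, and the sample values are illustrative assumptions, not part of the API.

```python
def tile_bytes(swizzle_bytes: int, elem_size: int, rows: int, cols: int) -> int:
    """Mirror one `mul(...)` term of the documented `k_bytes` expression.

    `rows` is a hypothetical stand-in for the operand rendered as `config`
    in `(config // Int(8))`; the docs do not show which field is used.
    """
    # Elements that fit in one swizzle span: swizzle_mode.bytes() // size_of[dtype]()
    swizzle_elems = swizzle_bytes // elem_size
    return (
        swizzle_elems
        * (rows // 8)
        * (cols // swizzle_elems)
        * elem_size
        * 8
    )


def k_bytes_sketch(
    nope_swizzle_bytes: int,
    nope_elem_size: int,
    rope_swizzle_bytes: int,
    rope_elem_size: int,
    rows: int,
) -> int:
    """Sum of the nope term (128 columns) and the rope term (64 columns)."""
    nope = tile_bytes(nope_swizzle_bytes, nope_elem_size, rows, 128)
    rope = tile_bytes(rope_swizzle_bytes, rope_elem_size, rows, 64)
    return nope + rope


# Assumed example: 128-byte swizzle, 2-byte elements for both parts, 64 rows.
example_total = k_bytes_sketch(128, 2, 128, 2, rows=64)
```

When `rows` is a multiple of 8 and the column count is a multiple of the swizzle span, each term reduces to `rows * cols * elem_size`, which is the plain byte size of the tile.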

k_elements

comptime k_elements = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_nope_tma_layout + MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_rope_tma_layout)

k_nope_bytes

comptime k_nope_bytes = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_nope_tma_layout * size_of[k_nope_dtype]())

k_nope_tma_layout

comptime k_nope_tma_layout = Layout[*?, *?].static_product

k_rope_bytes

comptime k_rope_bytes = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_rope_tma_layout * size_of[k_rope_dtype]())
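`k_nope_bytes` and `k_rope_bytes` each scale an element count by that part's own element size, while `k_elements` adds the two element counts directly. Because the nope and rope parts may use different dtypes, the total byte size cannot in general be recovered from `k_elements` and a single element size. A small Python sketch of these relationships follows; the `KSizes` class and the numbers are made up for illustration, and in the real struct the counts come from the layouts' `static_product` and the sizes from `size_of`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class KSizes:
    """Hypothetical container standing in for the comptime members."""

    nope_elements: int   # stands in for k_nope_tma_layout (an element count)
    rope_elements: int   # stands in for k_rope_tma_layout (an element count)
    nope_elem_size: int  # stands in for size_of[k_nope_dtype]()
    rope_elem_size: int  # stands in for size_of[k_rope_dtype]()

    @property
    def k_elements(self) -> int:
        # k_elements = k_nope_tma_layout + k_rope_tma_layout
        return self.nope_elements + self.rope_elements

    @property
    def k_nope_bytes(self) -> int:
        # k_nope_bytes = k_nope_tma_layout * size_of[k_nope_dtype]()
        return self.nope_elements * self.nope_elem_size

    @property
    def k_rope_bytes(self) -> int:
        # k_rope_bytes = k_rope_tma_layout * size_of[k_rope_dtype]()
        return self.rope_elements * self.rope_elem_size


# Assumed example: 1-byte nope elements, 2-byte rope elements.
sizes = KSizes(nope_elements=8192, rope_elements=4096,
               nope_elem_size=1, rope_elem_size=2)
total_bytes = sizes.k_nope_bytes + sizes.k_rope_bytes
```

In this assumed example the total is 16384 bytes for 12288 elements, which matches neither `k_elements * 1` nor `k_elements * 2`; the per-part byte members are what account for the mixed element sizes.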

k_rope_tma_layout

comptime k_rope_tma_layout = Layout[*?, *?].static_product

k_tma_layout

comptime k_tma_layout = Layout[*?, *?].static_product

KPairType

comptime KPairType = TMADestination[k_nope_dtype, MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_tma_layout]

SMemType

comptime SMemType = UnsafePointer[Scalar[k_nope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

v_bytes

comptime v_bytes = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].v_elements * size_of[k_nope_dtype]())

v_elements

comptime v_elements = MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].v_tma_layout

v_tma_layout

comptime v_tma_layout = Layout[*?, *?].static_product

VPairType

comptime VPairType = TMADestination[k_nope_dtype, MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].v_tma_layout]