Skip to main content

Mojo struct

MLAKVLayouts

struct MLAKVLayouts[k_nope_dtype: DType, k_rope_dtype: DType, kv_scale_dtype: DType, config: MLAConfig[config.qkv_dtype, rope_gmem_dtype=config.rope_gmem_dtype, rope_mma_dtype=config.rope_mma_dtype, scale_dtype=config.scale_dtype]]

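Compile-time (`comptime`) layout and size metadata for MLA K/V tiles: element counts and byte sizes for the K-nope, K-rope, combined K, and V TMA tiles, plus the shared-memory pointer type and the TMA destination types for K and V.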

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

k_bytes

comptime k_bytes = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_nope_bytes + MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_rope_bytes)

k_elements

comptime k_elements = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_nope_tma_layout + MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_rope_tma_layout)

k_nope_bytes

comptime k_nope_bytes = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_nope_tma_layout * size_of[k_nope_dtype]())

k_nope_tma_layout

comptime k_nope_tma_layout = Layout[Coord[ComptimeInt[8], ComptimeInt[(config // 8)]], Coord[ComptimeInt[(config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())], ComptimeInt[(128 // (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]], Coord[ComptimeInt[(config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())], ComptimeInt[(8 * (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (128 == (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())) else (config * (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]]].static_product

k_rope_bytes

comptime k_rope_bytes = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_rope_tma_layout * size_of[k_rope_dtype]())

k_rope_tma_layout

comptime k_rope_tma_layout = Layout[Coord[ComptimeInt[8], ComptimeInt[(config // 8)]], Coord[ComptimeInt[(config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]())], ComptimeInt[(64 // (config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]()))]], Coord[ComptimeInt[(config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]())], ComptimeInt[(8 * (config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (64 == (config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]())) else (config * (config.rope_gmem_swizzle_mode.bytes() // size_of[k_rope_dtype]()))]]].static_product

k_tma_layout

comptime k_tma_layout = Layout[Coord[ComptimeInt[8], ComptimeInt[(config // 8)]], Coord[ComptimeInt[(config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())], ComptimeInt[(config // (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]], Coord[ComptimeInt[(config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())], ComptimeInt[(8 * (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (config == (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())) else (config * (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]]].static_product

KPairType

comptime KPairType = TMADestination[k_nope_dtype, MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].k_tma_layout]

SMemType

comptime SMemType = UnsafePointer[Scalar[k_nope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

v_bytes

comptime v_bytes = (MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].v_elements * size_of[k_nope_dtype]())

v_elements

comptime v_elements = MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].v_tma_layout

v_tma_layout

comptime v_tma_layout = Layout[Coord[ComptimeInt[(config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())], ComptimeInt[(128 // (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]], Coord[ComptimeInt[8], ComptimeInt[(config // 8)]], Coord[ComptimeInt[1], ComptimeInt[0 if (128 == (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())) else (config * (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]], Coord[ComptimeInt[(config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]())], ComptimeInt[(8 * (config.qkv_swizzle_mode.bytes() // size_of[k_nope_dtype]()))]]].static_product

VPairType

comptime VPairType = TMADestination[k_nope_dtype, MLAKVLayouts[k_nope_dtype, k_rope_dtype, kv_scale_dtype, config].v_tma_layout]

Was this page helpful?