Skip to main content

Mojo struct

MLASmemStorage

struct MLASmemStorage[qkv_dtype: DType, rope_dtype: DType, num_mbars: Int, config: MLAConfig]

Fields

  • q_smem (InlineArray[UInt8, MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].q_bytes])
  • kv_smem (InlineArray[UInt8, MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].kv_bytes])
  • q_scale_smem (InlineArray[UInt8, MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].q_scale_bytes])
  • k_scale_smem (InlineArray[UInt8, MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].k_scale_bytes])
  • correction_smem (InlineArray[Float32, MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].correction_smem_size])
  • mbar_base (InlineArray[SharedMemBarrier, num_mbars])
  • tmem_addr (InlineArray[UInt32, 1])

Implemented traits

AnyType, ImplicitlyDestructible

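comptime members

Note: in several expressions below, a bare `config` appears as an arithmetic operand (for example `(config * config)`). This appears to be a rendering artifact in which the accessed member of `config` was omitted; consult the struct's source for the exact definitions.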

correction_smem_size

comptime correction_smem_size = config.correction_smem_elements()

k_scale_bytes

comptime k_scale_bytes = (config * size_of[DType.float32]())

kv_bytes

comptime kv_bytes = (MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].kv_nope_bytes + MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].kv_rope_bytes)

kv_nope_bytes

comptime kv_nope_bytes = (((config * config) * size_of[qkv_dtype]()) * MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].num_kv_stages)

kv_rope_bytes

comptime kv_rope_bytes = (((config * config) * size_of[rope_dtype]()) * MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].num_kv_stages)

num_kv_stages

comptime num_kv_stages = (config * config)

q_bytes

comptime q_bytes = (MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].q_nope_bytes + MLASmemStorage[qkv_dtype, rope_dtype, num_mbars, config].q_rope_bytes)

q_nope_bytes

comptime q_nope_bytes = ((config * config) * size_of[qkv_dtype]())

q_rope_bytes

comptime q_rope_bytes = ((config * config) * size_of[rope_dtype]())

q_scale_bytes

comptime q_scale_bytes = (config * size_of[DType.float32]())

Was this page helpful?