Mojo struct
Depth512AttentionSMem
struct Depth512AttentionSMem[qkv_dtype: DType, //, config: Depth512SM100Config[qkv_dtype]]
Shared memory layout manager for depth=512 pair-CTA attention kernels.
Stores a base pointer into dynamic shared memory and provides an accessor method for each region (Q, P, KV pipeline, correction, mbarriers, TMEM address). All byte-offset arithmetic is evaluated at compile time (comptime), so each accessor compiles down to a single pointer add plus a bitcast.
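As a rough model of what each accessor does, the Python sketch below treats shared memory as a flat byte buffer. An accessor adds a fixed byte offset to the byte-typed base and reinterprets the bytes from that point on as another element type. This is an illustration only, not Mojo. The `region` helper and the 64-byte buffer are hypothetical, and `memoryview.cast` stands in for the pointer bitcast.

```python
import struct

# Hypothetical stand-in for the kernel's dynamic shared memory.
smem = bytearray(64)
base = memoryview(smem)  # byte-typed view, like the UInt8 base pointer


def region(base: memoryview, byte_offset: int, fmt: str) -> memoryview:
    """Model of an accessor: offset the byte base by a constant amount,
    then reinterpret ("bitcast") the remaining bytes as `fmt` elements."""
    return base[byte_offset:].cast(fmt)


# A Float32 region starting 16 bytes into the buffer.
corr = region(base, 16, "f")
corr[0] = 1.5
# The write is visible through the raw bytes at the same address.
value = struct.unpack_from("f", smem, 16)[0]
```

Because the offset is a constant, the real accessors need no runtime bookkeeping beyond the add and the reinterpretation.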
Parameters
- qkv_dtype (DType): Element type of Q/K/V data in shared memory.
- config (Depth512SM100Config[qkv_dtype]): Depth512SM100 configuration (tile sizes, depths, staging counts, etc.). All fields are comptime-accessible when the config is a comptime parameter.
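Fields
- base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory. Each region accessor adds a comptime byte offset to it.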
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
correction_byte_offset
comptime correction_byte_offset = size_of[qkv_dtype]() * (config.BM * config.qk_depth + config.BM * config.BN + (config.BN // 2) * config.BK0 * config.num_kv_stages)
correction_bytes
comptime correction_bytes = config.BM * size_of[DType.float32]()
correction_offset
comptime correction_offset = SIMD((size_of[qkv_dtype]() * (config.BM * config.qk_depth + config.BM * config.BN + (config.BN // 2) * config.BK0 * config.num_kv_stages)) // size_of[DType.float32]())
kv_byte_offset
comptime kv_byte_offset = size_of[qkv_dtype]() * (config.BM * config.qk_depth + config.BM * config.BN)
kv_offset
comptime kv_offset = SIMD((size_of[qkv_dtype]() * (config.BM * config.qk_depth + config.BM * config.BN)) // size_of[qkv_dtype]())
kv_stage_bytes
comptime kv_stage_bytes = (config.BN // 2) * config.BK0 * size_of[qkv_dtype]()
kv_total_bytes
comptime kv_total_bytes = config.num_kv_stages * (config.BN // 2) * config.BK0 * size_of[qkv_dtype]()
mbar_byte_offset
comptime mbar_byte_offset = size_of[qkv_dtype]() * (config.BM * config.qk_depth + config.BM * config.BN + (config.BN // 2) * config.BK0 * config.num_kv_stages) + config.BM * size_of[DType.float32]()
mbar_bytes
comptime mbar_bytes = (2 * config.num_kv_stages + (Int(10) if config.split_o else Int(8))) * size_of[SharedMemBarrier]()
mbar_offset
comptime mbar_offset = SIMD((size_of[qkv_dtype]() * (config.BM * config.qk_depth + config.BM * config.BN + (config.BN // 2) * config.BK0 * config.num_kv_stages) + config.BM * size_of[DType.float32]()) // size_of[SharedMemBarrier]())
num_fixed_mbars
comptime num_fixed_mbars = Int(10) if config.split_o else Int(8)
num_kv_mbars
comptime num_kv_mbars = 2 * config.num_kv_stages
p_byte_offset
comptime p_byte_offset = Depth512AttentionSMem[config].q_bytes
p_bytes
comptime p_bytes = config.BM * config.BN * size_of[qkv_dtype]()
p_offset
comptime p_offset = SIMD((size_of[qkv_dtype]() * config.BM * config.qk_depth) // size_of[qkv_dtype]())
q_byte_offset
comptime q_byte_offset = Int(0)
q_bytes
comptime q_bytes = config.BM * config.qk_depth * size_of[qkv_dtype]()
q_offset
comptime q_offset = Int32(0)
tmem_addr_byte_offset
comptime tmem_addr_byte_offset = size_of[qkv_dtype]() * (config.BM * config.qk_depth + config.BM * config.BN + (config.BN // 2) * config.BK0 * config.num_kv_stages) + config.BM * size_of[DType.float32]() + (2 * config.num_kv_stages + (Int(10) if config.split_o else Int(8))) * size_of[SharedMemBarrier]()
total_mbars
comptime total_mbars = (Int(10) if config.split_o else Int(8)) + 2 * config.num_kv_stages
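Taken together, the comptime members lay the regions out back to back in the order Q | P | KV stages | correction | mbarriers | TMEM address. The Python sketch below reproduces that arithmetic so the offsets can be checked for a given configuration. `LayoutConfig` is a hypothetical stand-in for the subset of Depth512SM100Config fields used above. The tile sizes in the example are illustrative, not taken from a real config. The 8-byte SharedMemBarrier size is an assumption, since this page does not state it.

```python
from dataclasses import dataclass

F32_BYTES = 4  # size_of[DType.float32]()


@dataclass(frozen=True)
class LayoutConfig:
    """Hypothetical subset of Depth512SM100Config: only the fields the
    comptime members reference."""
    BM: int
    BN: int
    BK0: int
    qk_depth: int
    num_kv_stages: int
    split_o: bool


def smem_layout(cfg: LayoutConfig, elt_bytes: int, mbar_bytes_each: int) -> dict:
    """Byte and element offsets of each region, following the comptime
    formulas: Q | P | KV stages | correction | mbarriers | TMEM address."""
    q_bytes = cfg.BM * cfg.qk_depth * elt_bytes
    p_bytes = cfg.BM * cfg.BN * elt_bytes
    kv_stage_bytes = (cfg.BN // 2) * cfg.BK0 * elt_bytes
    kv_total_bytes = cfg.num_kv_stages * kv_stage_bytes
    correction_bytes = cfg.BM * F32_BYTES
    num_fixed_mbars = 10 if cfg.split_o else 8
    total_mbars = num_fixed_mbars + 2 * cfg.num_kv_stages
    mbar_bytes = total_mbars * mbar_bytes_each

    off = {"q_byte_offset": 0}
    off["p_byte_offset"] = off["q_byte_offset"] + q_bytes
    off["kv_byte_offset"] = off["p_byte_offset"] + p_bytes
    off["correction_byte_offset"] = off["kv_byte_offset"] + kv_total_bytes
    off["mbar_byte_offset"] = off["correction_byte_offset"] + correction_bytes
    off["tmem_addr_byte_offset"] = off["mbar_byte_offset"] + mbar_bytes
    # Element offsets: byte offset divided by the size of the type the
    # corresponding accessor bitcasts to.
    off["p_offset"] = off["p_byte_offset"] // elt_bytes
    off["kv_offset"] = off["kv_byte_offset"] // elt_bytes
    off["correction_offset"] = off["correction_byte_offset"] // F32_BYTES
    off["mbar_offset"] = off["mbar_byte_offset"] // mbar_bytes_each
    off["total_mbars"] = total_mbars
    return off


# Illustrative values only: 2-byte (e.g. bf16) elements, assumed 8-byte barriers.
layout = smem_layout(
    LayoutConfig(BM=64, BN=128, BK0=64, qk_depth=512, num_kv_stages=2, split_o=False),
    elt_bytes=2,
    mbar_bytes_each=8,
)
```

The formulas insert no padding between regions. The alignment of each region therefore follows from the sizes of the regions before it. One example is the 8-byte alignment that PTX requires for an mbarrier object.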
Methods
__init__
def __init__() -> Self
Obtain the base pointer from the kernel's dynamic shared memory.
q_smem
def q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the Q region (offset 0).
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
o_smem
def o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
Same physical memory as Q, bitcast to the output element type.
Returns:
UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
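Because o_smem aliases the Q region, the number of output elements it can address within that region is q_bytes divided by the output element size. The small arithmetic sketch below uses hypothetical tile sizes. It assumes, rather than relies on anything stated here, that Q is no longer needed when the output is written into the same storage.

```python
def o_smem_capacity(BM: int, qk_depth: int, qkv_bytes: int, out_bytes: int) -> int:
    """Output elements that fit in the Q region (q_bytes) when it is
    reinterpreted as an output type with `out_bytes` bytes per element."""
    q_bytes = BM * qk_depth * qkv_bytes
    return q_bytes // out_bytes


# Hypothetical tile: BM=64, qk_depth=512, 2-byte Q elements.
same_size_out = o_smem_capacity(64, 512, 2, 2)  # 2-byte output type
f32_out = o_smem_capacity(64, 512, 2, 4)        # 4-byte output type
```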
p_smem
def p_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the P buffer region for SS MMA P@V.
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
kv_smem_base
def kv_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the KV pipeline region (stage 0).
Each stage holds (BN//2)*BK0 elements of qkv_dtype per SM. A stage is interpreted as K (BN//2 × BK0) during Q@K' or as V (BK1 × ov_depth//2) during P@V. The pair-CTA MMA reads from both SMs.
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
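kv_total_bytes equals num_kv_stages * kv_stage_bytes, which is consistent with stages packed back to back from kv_smem_base(). Under that assumption, the start of stage `s` is `s` times the per-stage element count. The helper below is hypothetical and only illustrates the arithmetic. It does not model any swizzling the kernel may apply within a stage.

```python
def kv_stage_element_offset(stage: int, BN: int, BK0: int, num_kv_stages: int) -> int:
    """Element offset of KV pipeline stage `stage` relative to
    kv_smem_base(), assuming stages are packed back to back."""
    if not 0 <= stage < num_kv_stages:
        raise IndexError(f"stage {stage} out of range [0, {num_kv_stages})")
    stage_elems = (BN // 2) * BK0  # per-SM elements in one stage
    return stage * stage_elems


# Hypothetical sizes: BN=128, BK0=64, two stages.
offsets = [kv_stage_element_offset(s, 128, 64, 2) for s in range(2)]
```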
correction_smem
def correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the correction region (BM Float32 elements).
Returns:
UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
mbar_base
def mbar_base(self) -> MBarType
Base of the barrier region.
Layout: num_fixed_mbars fixed barriers (10 when split_o, 8 otherwise) followed by 2*num_kv_stages KV pipeline barriers. The Depth512MBars struct wraps this pointer with named accessors for each barrier role.
Returns:
MBarType
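The barrier layout described for mbar_base can be written as two index ranges relative to the base pointer. The fixed barriers come first, followed by the KV pipeline barriers. This is a sketch of the index arithmetic only. Which index serves which role (including how KV barriers map to stages) is defined by Depth512MBars and is not assumed here.

```python
def mbar_index_ranges(num_kv_stages: int, split_o: bool) -> tuple[range, range]:
    """Barrier indices relative to mbar_base(): num_fixed_mbars fixed
    barriers, then 2 * num_kv_stages KV pipeline barriers."""
    num_fixed_mbars = 10 if split_o else 8
    fixed = range(0, num_fixed_mbars)
    kv = range(num_fixed_mbars, num_fixed_mbars + 2 * num_kv_stages)
    return fixed, kv


fixed, kv = mbar_index_ranges(num_kv_stages=3, split_o=True)
```

The end of the KV range equals total_mbars, and multiplying it by the barrier size gives mbar_bytes.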
tmem_addr_ptr
def tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Pointer to the single UInt32 storing the TMEM address.
Returns:
UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
smem_size