Mojo struct
Depth512AttentionSMem
struct Depth512AttentionSMem[qkv_dtype: DType, //, config: Depth512SM100Config[qkv_dtype]]
Shared memory layout manager for depth=512 pair-CTA attention kernels.
Stores a base pointer into dynamic shared memory and provides accessor methods for each region (Q, P, KV pipeline, correction, mbarriers, tmem address). All byte-offset arithmetic is comptime so the accessors compile down to a single pointer add + bitcast.
Parameters
- qkv_dtype (DType): Element type of Q/K/V data in shared memory.
- config (Depth512SM100Config[qkv_dtype]): Depth512SM100 configuration (tile sizes, depths, staging counts, etc.). All fields are comptime-accessible when the config is a comptime parameter.
Fields
- base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
correction_byte_offset
comptime correction_byte_offset = (Depth512AttentionSMem[config].kv_byte_offset + Depth512AttentionSMem[config].kv_total_bytes)
correction_bytes
comptime correction_bytes = (config * size_of[DType.float32]())
correction_offset
comptime correction_offset = SIMD((Depth512AttentionSMem[config].correction_byte_offset // size_of[DType.float32]()))
kv_byte_offset
comptime kv_byte_offset = (Depth512AttentionSMem[config].p_byte_offset + Depth512AttentionSMem[config].p_bytes)
kv_offset
comptime kv_offset = SIMD((Depth512AttentionSMem[config].kv_byte_offset // size_of[qkv_dtype]()))
kv_stage_bytes
comptime kv_stage_bytes = (((config // 2) * config) * size_of[qkv_dtype]())
kv_total_bytes
comptime kv_total_bytes = (config * Depth512AttentionSMem[config].kv_stage_bytes)
mbar_byte_offset
comptime mbar_byte_offset = (Depth512AttentionSMem[config].correction_byte_offset + Depth512AttentionSMem[config].correction_bytes)
mbar_bytes
comptime mbar_bytes = (Depth512AttentionSMem[config].total_mbars * size_of[SharedMemBarrier]())
mbar_offset
comptime mbar_offset = SIMD((Depth512AttentionSMem[config].mbar_byte_offset // size_of[SharedMemBarrier]()))
num_fixed_mbars
comptime num_fixed_mbars = 10 if config.split_o else 8
num_kv_mbars
comptime num_kv_mbars = (2 * config)
p_byte_offset
comptime p_byte_offset = Depth512AttentionSMem[config].q_bytes
p_bytes
comptime p_bytes = ((config * config) * size_of[qkv_dtype]())
p_offset
comptime p_offset = SIMD((Depth512AttentionSMem[config].p_byte_offset // size_of[qkv_dtype]()))
q_byte_offset
comptime q_byte_offset = 0
q_bytes
comptime q_bytes = ((config * config) * size_of[qkv_dtype]())
q_offset
comptime q_offset = Int32(0)
tmem_addr_byte_offset
comptime tmem_addr_byte_offset = (Depth512AttentionSMem[config].mbar_byte_offset + Depth512AttentionSMem[config].mbar_bytes)
total_mbars
comptime total_mbars = (Depth512AttentionSMem[config].num_fixed_mbars + Depth512AttentionSMem[config].num_kv_mbars)
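The byte offsets above pack the regions back to back in the order Q, P, KV pipeline, correction, mbarriers, TMEM address, and each element offset is the byte offset divided by the element size. A minimal model of that arithmetic follows, written in Python rather than Mojo so it can be run standalone. The names `region_offsets` and `element_offset` and all sample sizes are hypothetical; they are not taken from `Depth512SM100Config`.

```python
from itertools import accumulate

# Region order as implied by the chained comptime byte offsets.
REGIONS = ("q", "p", "kv", "correction", "mbar", "tmem_addr")


def region_offsets(sizes_bytes):
    """Return {region: byte_offset} for regions packed back to back.

    `sizes_bytes` maps every region before `tmem_addr` to its size in bytes.
    """
    sizes = [sizes_bytes[name] for name in REGIONS[:-1]]
    # Each offset is the running sum of all preceding region sizes.
    starts = [0, *accumulate(sizes)]
    return dict(zip(REGIONS, starts))


def element_offset(byte_offset, elem_size):
    """Convert a byte offset to an element offset (the `//` in the listing)."""
    assert byte_offset % elem_size == 0, "region start must be element-aligned"
    return byte_offset // elem_size


# Hypothetical sizes chosen only to make the arithmetic easy to follow.
sizes = {"q": 4096, "p": 2048, "kv": 8192, "correction": 512, "mbar": 112}
offsets = region_offsets(sizes)
print(offsets)
```

In the struct itself these sums are comptime constants, so none of this arithmetic would run on the device; the model only makes the dependency chain between the members explicit.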
Methods
__init__
__init__() -> Self
Obtain the base pointer from the kernel's dynamic shared memory.
q_smem
q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the Q region (offset 0).
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
o_smem
o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
Same physical memory as Q, bitcast to the output element type.
Returns:
UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
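Because `o_smem` returns the Q base pointer reinterpreted as another element type, anything written through it lands on Q's bytes. The Python sketch below models that aliasing with two `memoryview` casts over one buffer. The buffer size and the element formats (16-bit for Q, 32-bit float for the output) are illustrative assumptions, not values taken from the kernel.

```python
import struct

# Stand-in for the start of dynamic shared memory; 16 bytes is an arbitrary size.
smem = bytearray(16)
base = memoryview(smem)

# Two typed views over the same bytes, both starting at offset 0.
q_view = base.cast("H")  # hypothetical 16-bit Q elements
o_view = base.cast("f")  # hypothetical float32 output elements

q_view[0] = 0x1234
q_view[1] = 0x5678

# Writing one output element replaces the bytes of two 16-bit Q elements.
o_view[0] = 1.0

q_after = (q_view[0], q_view[1])
expected = struct.unpack("=2H", struct.pack("=f", 1.0))
print(q_after == expected)
```

The practical consequence is ordering: under this model, Q must no longer be needed by the time output is written into the shared region.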
p_smem
p_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the P buffer region for SS MMA P@V.
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
kv_smem_base
kv_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the KV pipeline region (stage 0).
Each stage holds (BN//2)*BK0 elements of qkv_dtype per SM. Interpreted as K (BN//2 × BK0) during Q@K' or V (BK1 × ov_depth//2) during P@V. The pair-CTA MMA reads from both SMs.
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
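As a rough check of the stage sizing described for `kv_smem_base`, the sketch below computes the per-SM stage size and tests whether a V tile of BK1 × ov_depth//2 elements fits in it. It is written in Python for illustration. The tile values (`BN = 128`, `BK0 = 128`, `BK1 = 32`, `ov_depth = 512`) and the 2-byte element size are hypothetical, not taken from `Depth512SM100Config`.

```python
def kv_stage_elems(bn, bk0):
    """Elements held by one KV stage on one SM: (BN // 2) * BK0."""
    return (bn // 2) * bk0


def v_tile_elems(bk1, ov_depth):
    """Elements in the per-SM V tile: BK1 * (ov_depth // 2)."""
    return bk1 * (ov_depth // 2)


# Hypothetical tile sizes and a 2-byte element type (for example bfloat16).
BN, BK0, BK1, OV_DEPTH, ELEM_SIZE = 128, 128, 32, 512, 2

stage_elems = kv_stage_elems(BN, BK0)
stage_bytes = stage_elems * ELEM_SIZE
v_fits = v_tile_elems(BK1, OV_DEPTH) <= stage_elems

print(stage_elems, stage_bytes, v_fits)
```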
correction_smem
correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the correction region (BM Float32 elements).
Returns:
UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
mbar_base
mbar_base(self) -> MBarType
Base of the barrier region.
Layout: num_fixed_mbars fixed barriers (10 when split_o, 8 otherwise) followed by 2*num_kv_stages KV pipeline barriers. The Depth512MBars struct wraps this pointer with named accessors for each barrier role.
Returns:
MBarType
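The barrier layout described for `mbar_base` is simple enough to model directly. This Python sketch mirrors `num_fixed_mbars`, `num_kv_mbars`, and `total_mbars`. The 8-byte barrier size is an assumption (the listing only says `size_of[SharedMemBarrier]()`), and `kv_mbar_indices` is a hypothetical helper; the real named accessors live in `Depth512MBars`.

```python
MBAR_SIZE_BYTES = 8  # assumption: one 64-bit word per barrier


def num_fixed_mbars(split_o):
    """Fixed barriers: 10 when split_o, 8 otherwise."""
    return 10 if split_o else 8


def total_mbars(split_o, num_kv_stages):
    """Fixed barriers plus two barriers per KV pipeline stage."""
    return num_fixed_mbars(split_o) + 2 * num_kv_stages


def kv_mbar_indices(split_o, num_kv_stages):
    """Indices of the KV pipeline barriers, which follow the fixed ones."""
    return range(num_fixed_mbars(split_o), total_mbars(split_o, num_kv_stages))


# Hypothetical configuration: split_o enabled, 2 KV stages.
count = total_mbars(True, 2)
print(count, count * MBAR_SIZE_BYTES, list(kv_mbar_indices(True, 2)))
```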
tmem_addr_ptr
tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Pointer to the single UInt32 storing the TMEM address.
Returns:
UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
smem_size