Mojo struct
Depth512AttentionSMem
struct Depth512AttentionSMem[qkv_dtype: DType, //, config: Depth512SM100Config[qkv_dtype]]
Shared memory layout manager for depth=512 pair-CTA attention kernels.
Stores a base pointer into dynamic shared memory and provides accessor methods for each region (Q, P, KV pipeline, correction, mbarriers, tmem address). All byte-offset arithmetic is comptime so the accessors compile down to a single pointer add + bitcast.
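The "pointer add + bitcast" idea can be modelled outside Mojo. The following is a minimal Python sketch, not the struct's implementation: a `bytearray` stands in for dynamic shared memory, and `CORRECTION_BYTE_OFFSET` is a hypothetical value chosen for illustration. It shows that indexing a reinterpreted buffer by element offset (byte offset divided by element size) reaches the same location as the byte offset itself.

```python
import struct

# Stand-in for dynamic shared memory: a flat, zero-initialised byte buffer.
smem = bytearray(256)

# Hypothetical byte offset of a Float32 region. It must be a multiple of the
# element size for the element-offset form to address the same location.
CORRECTION_BYTE_OFFSET = 64
F32_SIZE = struct.calcsize("f")
CORRECTION_OFFSET = CORRECTION_BYTE_OFFSET // F32_SIZE

# "Bitcast" the byte buffer to float32 elements, then index by element offset.
as_f32 = memoryview(smem).cast("f")
as_f32[CORRECTION_OFFSET] = 1.5

# Reading the same location by byte offset yields the same value.
(value_at_byte_offset,) = struct.unpack_from("f", smem, CORRECTION_BYTE_OFFSET)
```

In the struct itself both the byte offset and the element offset are comptime constants, so none of this arithmetic happens at run time.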
Parameters
- qkv_dtype (DType): Element type of Q/K/V data in shared memory.
- config (Depth512SM100Config): Depth512SM100 configuration (tile sizes, depths, staging counts, etc.). All fields are comptime-accessible when the config is a comptime parameter.
Fields
- base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable
comptime members
correction_byte_offset
comptime correction_byte_offset = (Depth512AttentionSMem[config].kv_byte_offset + Depth512AttentionSMem[config].kv_total_bytes)
correction_bytes
comptime correction_bytes = (Depth512SM100Config[qkv_dtype].BM * size_of[DType.float32]())
correction_offset
comptime correction_offset = SIMD((Depth512AttentionSMem[config].correction_byte_offset // size_of[DType.float32]()))
kv_byte_offset
comptime kv_byte_offset = (Depth512AttentionSMem[config].p_byte_offset + Depth512AttentionSMem[config].p_bytes)
kv_offset
comptime kv_offset = SIMD((Depth512AttentionSMem[config].kv_byte_offset // size_of[qkv_dtype]()))
kv_stage_bytes
comptime kv_stage_bytes = (((config // 2) * config) * size_of[qkv_dtype]())
kv_total_bytes
comptime kv_total_bytes = (config * Depth512AttentionSMem[config].kv_stage_bytes)
mbar_byte_offset
comptime mbar_byte_offset = (Depth512AttentionSMem[config].correction_byte_offset + Depth512AttentionSMem[config].correction_bytes)
mbar_bytes
comptime mbar_bytes = (Depth512AttentionSMem[config].total_mbars * size_of[SharedMemBarrier]())
mbar_offset
comptime mbar_offset = SIMD((Depth512AttentionSMem[config].mbar_byte_offset // size_of[SharedMemBarrier]()))
num_fixed_mbars
comptime num_fixed_mbars = 10
num_kv_mbars
comptime num_kv_mbars = (2 * config)
p_byte_offset
comptime p_byte_offset = Depth512AttentionSMem[config].q_bytes
p_bytes
comptime p_bytes = ((Depth512SM100Config[qkv_dtype].BM * config) * size_of[qkv_dtype]())
p_offset
comptime p_offset = SIMD((Depth512AttentionSMem[config].p_byte_offset // size_of[qkv_dtype]()))
q_byte_offset
comptime q_byte_offset = 0
q_bytes
comptime q_bytes = ((Depth512SM100Config[qkv_dtype].BM * config) * size_of[qkv_dtype]())
q_offset
comptime q_offset = SIMD(0)
tmem_addr_byte_offset
comptime tmem_addr_byte_offset = (Depth512AttentionSMem[config].mbar_byte_offset + Depth512AttentionSMem[config].mbar_bytes)
total_mbars
comptime total_mbars = (Depth512AttentionSMem[config].num_fixed_mbars + Depth512AttentionSMem[config].num_kv_mbars)
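The comptime members above chain together: each region starts where the previous one ends, in the order Q, P, KV, correction, mbarriers, TMEM address. The Python sketch below models that chain only. The `LayoutInputs` class, its field names, the example sizes, and the 8-byte barrier size are all assumptions for illustration; the real element counts come from `Depth512SM100Config`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LayoutInputs:
    """Hypothetical stand-ins for the quantities the comptime members use."""
    elem_size: int            # size_of[qkv_dtype]()
    q_elems: int              # element count of the Q region
    p_elems: int              # element count of the P region
    kv_stage_elems: int       # element count of one KV stage
    num_kv_stages: int        # number of KV pipeline stages
    bm: int                   # BM: number of Float32 correction elements
    mbar_size: int = 8        # assumed size_of[SharedMemBarrier]()
    num_fixed_mbars: int = 10


def byte_offsets(cfg: LayoutInputs) -> dict:
    """Return each region's starting byte offset, in layout order."""
    f32_size = 4
    total_mbars = cfg.num_fixed_mbars + 2 * cfg.num_kv_stages
    regions = [
        ("q", cfg.q_elems * cfg.elem_size),
        ("p", cfg.p_elems * cfg.elem_size),
        ("kv", cfg.num_kv_stages * cfg.kv_stage_elems * cfg.elem_size),
        ("correction", cfg.bm * f32_size),
        ("mbar", total_mbars * cfg.mbar_size),
    ]
    offsets = {}
    cursor = 0
    for name, size in regions:
        offsets[name] = cursor   # region starts where the previous one ended
        cursor += size
    offsets["tmem_addr"] = cursor  # follows the barrier region
    return offsets


# Small, made-up sizes so the arithmetic is easy to follow by hand.
example = LayoutInputs(elem_size=2, q_elems=1024, p_elems=512,
                       kv_stage_elems=256, num_kv_stages=3, bm=64)
layout = byte_offsets(example)
```

With these made-up inputs the regions start at byte 0 (Q), 2048 (P), 3072 (KV), 4608 (correction), 4864 (mbarriers), and 4992 (TMEM address).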
Methods
__init__
__init__() -> Self
Obtain the base pointer from the kernel's dynamic shared memory.
q_smem
q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the Q region (offset 0).
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
o_smem
o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
Same physical memory as Q, bitcast to the output element type.
Returns:
UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
p_smem
p_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the P buffer region for SS MMA P@V.
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
kv_smem_base
kv_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the KV pipeline region (stage 0).
Each stage holds (BN//2)*BK0 elements of qkv_dtype per SM. Interpreted as K (BN//2 × BK0) during Q@K' or V (BK1 × ov_depth//2) during P@V. The pair-CTA MMA reads from both SMs.
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
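The dual K/V interpretation of a KV stage can be checked with simple arithmetic. This Python sketch uses hypothetical tile sizes (`BN`, `BK0`, `BK1`, `OV_DEPTH` below are made-up values, not taken from any real config) and a 2-byte element type. The page does not state whether the K and V views must have equal element counts, so the sketch only checks that the V view fits inside one stage.

```python
def kv_stage_elems(bn: int, bk0: int) -> int:
    """Elements per stage per SM: (BN // 2) * BK0."""
    return (bn // 2) * bk0


def v_tile_elems(bk1: int, ov_depth: int) -> int:
    """Elements in the V view of a stage: BK1 * (ov_depth // 2)."""
    return bk1 * (ov_depth // 2)


# Hypothetical tile sizes, chosen only so the numbers are easy to verify.
BN, BK0, BK1, OV_DEPTH = 64, 128, 16, 512
ELEM_SIZE = 2  # assumed 2-byte qkv_dtype such as bfloat16

stage = kv_stage_elems(BN, BK0)      # K view: (BN // 2) x BK0
v_view = v_tile_elems(BK1, OV_DEPTH)  # V view: BK1 x (ov_depth // 2)
v_fits = v_view <= stage
stage_bytes = stage * ELEM_SIZE
```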
correction_smem
correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the correction region (BM Float32 elements).
Returns:
UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
mbar_base
mbar_base(self) -> MBarType
Base of the barrier region.
Layout: 10 fixed barriers followed by 2*num_kv_stages KV pipeline barriers. The WS4 barrier struct wraps this pointer with named accessors for each barrier role.
Returns:
MBarType
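The barrier layout described for mbar_base (10 fixed barriers, then 2*num_kv_stages KV pipeline barriers) implies a simple index split. The Python sketch below models it; the function names, the 8-byte barrier size, and the example offset are assumptions for illustration. It deliberately does not assign roles to individual barriers, since those are defined by the WS4 barrier struct and are not listed on this page.

```python
NUM_FIXED_MBARS = 10  # matches num_fixed_mbars on this page


def mbar_index_ranges(num_kv_stages: int):
    """Return (fixed, kv) index ranges within the barrier region."""
    total = NUM_FIXED_MBARS + 2 * num_kv_stages
    return range(0, NUM_FIXED_MBARS), range(NUM_FIXED_MBARS, total)


def mbar_byte_address(mbar_byte_offset: int, index: int, mbar_size: int = 8) -> int:
    """Byte offset of barrier `index`, assuming a fixed barrier size."""
    return mbar_byte_offset + index * mbar_size


# Hypothetical example: 3 KV stages, barrier region starting at byte 4864.
fixed, kv = mbar_index_ranges(3)
first_kv_barrier = mbar_byte_address(4864, kv.start)
```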
tmem_addr_ptr
tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Pointer to the single UInt32 storing the TMEM address.
Returns:
UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
smem_size