Mojo struct

Depth512AttentionSMem

struct Depth512AttentionSMem[qkv_dtype: DType, //, config: Depth512SM100Config[qkv_dtype]]

Shared memory layout manager for depth=512 pair-CTA attention kernels.

Stores a base pointer into dynamic shared memory and provides an accessor method for each region (Q, P, KV pipeline, correction, mbarriers, TMEM address). All byte-offset arithmetic is evaluated at compile time (comptime), so each accessor compiles down to a single pointer add followed by a bitcast.

Parameters​

  • qkv_dtype (DType): Element type of Q/K/V data in shared memory.
  • config (Depth512SM100Config[qkv_dtype]): Depth512SM100 configuration (tile sizes, depths, staging counts, etc.). All fields are comptime-accessible when the config is a comptime parameter.

Fields​

  • base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into dynamic shared memory.

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

correction_byte_offset​

comptime correction_byte_offset = (Depth512AttentionSMem[config].kv_byte_offset + Depth512AttentionSMem[config].kv_total_bytes)

correction_bytes​

comptime correction_bytes = (config * size_of[DType.float32]())

correction_offset​

comptime correction_offset = SIMD((Depth512AttentionSMem[config].correction_byte_offset // size_of[DType.float32]()))

kv_byte_offset​

comptime kv_byte_offset = (Depth512AttentionSMem[config].p_byte_offset + Depth512AttentionSMem[config].p_bytes)

kv_offset​

comptime kv_offset = SIMD((Depth512AttentionSMem[config].kv_byte_offset // size_of[qkv_dtype]()))

kv_stage_bytes​

comptime kv_stage_bytes = (((config // 2) * config) * size_of[qkv_dtype]())

kv_total_bytes​

comptime kv_total_bytes = (config * Depth512AttentionSMem[config].kv_stage_bytes)

mbar_byte_offset​

comptime mbar_byte_offset = (Depth512AttentionSMem[config].correction_byte_offset + Depth512AttentionSMem[config].correction_bytes)

mbar_bytes​

comptime mbar_bytes = (Depth512AttentionSMem[config].total_mbars * size_of[SharedMemBarrier]())

mbar_offset​

comptime mbar_offset = SIMD((Depth512AttentionSMem[config].mbar_byte_offset // size_of[SharedMemBarrier]()))

num_fixed_mbars​

comptime num_fixed_mbars = 10 if config.split_o else 8

num_kv_mbars​

comptime num_kv_mbars = (2 * config)

p_byte_offset​

comptime p_byte_offset = Depth512AttentionSMem[config].q_bytes

p_bytes​

comptime p_bytes = ((config * config) * size_of[qkv_dtype]())

p_offset​

comptime p_offset = SIMD((Depth512AttentionSMem[config].p_byte_offset // size_of[qkv_dtype]()))

q_byte_offset​

comptime q_byte_offset = 0

q_bytes​

comptime q_bytes = ((config * config) * size_of[qkv_dtype]())

q_offset​

comptime q_offset = Int32(0)

tmem_addr_byte_offset​

comptime tmem_addr_byte_offset = (Depth512AttentionSMem[config].mbar_byte_offset + Depth512AttentionSMem[config].mbar_bytes)

total_mbars​

comptime total_mbars = (Depth512AttentionSMem[config].num_fixed_mbars + Depth512AttentionSMem[config].num_kv_mbars)
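The byte offsets above form a simple running sum: each region starts where the previous one ends, in the order Q, P, KV, correction, mbarriers, TMEM address. The following Python sketch models that arithmetic only; it is not the Mojo implementation. The `RegionSizes` class, the helper names, and every numeric size are hypothetical, and the final total assumes that the TMEM address slot (a single UInt32) is the last region, which this page does not state.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RegionSizes:
    """Byte size of each region. All values used below are hypothetical."""
    q_bytes: int
    p_bytes: int
    kv_total_bytes: int
    correction_bytes: int
    mbar_bytes: int


def byte_offsets(s: RegionSizes) -> dict:
    """Chain the region offsets in the same order as the comptime members."""
    q = 0                              # q_byte_offset
    p = q + s.q_bytes                  # p_byte_offset
    kv = p + s.p_bytes                 # kv_byte_offset
    corr = kv + s.kv_total_bytes       # correction_byte_offset
    mbar = corr + s.correction_bytes   # mbar_byte_offset
    tmem = mbar + s.mbar_bytes         # tmem_addr_byte_offset
    return {"q": q, "p": p, "kv": kv, "correction": corr,
            "mbar": mbar, "tmem_addr": tmem}


def element_offset(byte_offset: int, elem_size: int) -> int:
    """Element-typed offset: the byte offset floor-divided by the element size."""
    return byte_offset // elem_size


# Hypothetical sizes, chosen only to make the arithmetic easy to follow.
sizes = RegionSizes(q_bytes=65536, p_bytes=32768, kv_total_bytes=65536,
                    correction_bytes=512, mbar_bytes=112)
offsets = byte_offsets(sizes)

# Element offsets for a 2-byte qkv element and a 4-byte Float32 element.
kv_elem_offset = element_offset(offsets["kv"], 2)
correction_elem_offset = element_offset(offsets["correction"], 4)

# Assumption: the UInt32 TMEM address slot is the final region.
assumed_total_bytes = offsets["tmem_addr"] + 4
```

The element-typed offsets (`p_offset`, `kv_offset`, `correction_offset`, `mbar_offset`) are only exact when each byte offset is a multiple of the corresponding element size, which holds for the sizes used here.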

Methods​

__init__​

__init__() -> Self

Obtain the base pointer from the kernel's dynamic shared memory.

q_smem​

q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q region (offset 0).

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

o_smem​

o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

Same physical memory as Q, bitcast to the output element type.

Returns:

UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
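To illustrate what "same physical memory, bitcast to another element type" means, the sketch below uses a Python `bytearray` as a stand-in for shared memory and creates two typed views over the same bytes. This is a host-side analogy only: the buffer, the 16-bit and 32-bit element formats, and the variable names are hypothetical, and it says nothing about when the kernel may safely reuse the Q region for output.

```python
import struct

# Hypothetical stand-in for the dynamic shared memory buffer.
smem = bytearray(64)
base = memoryview(smem)

q_byte_offset = 0  # the Q region starts at offset 0

# Two views over the same 16 bytes: one as 16-bit elements (a Q-like type),
# one as 32-bit floats (an output-like type). No bytes are copied.
q_view = base[q_byte_offset:q_byte_offset + 16].cast("H")
o_view = base[q_byte_offset:q_byte_offset + 16].cast("f")

# A write through the output view is visible through the Q view,
# because both alias the same underlying storage.
o_view[0] = 1.0
expected = struct.pack("f", 1.0)
raw = bytes(base[0:4])
q_changed = any(v != 0 for v in q_view[0:2])
```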

p_smem​

p_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the P buffer region for SS MMA P@V.

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

kv_smem_base​

kv_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the KV pipeline region (stage 0).

Each stage holds (BN//2)*BK0 elements of qkv_dtype per SM. A stage is interpreted as K (BN//2 × BK0) during Q@K' or as V (BK1 × ov_depth//2) during P@V. The pair-CTA MMA reads from both SMs.

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
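The per-stage size can be worked through with concrete numbers. The sketch below is a Python model of the arithmetic, not the Mojo code. The tile sizes, element size, and stage count are hypothetical, and it assumes that stages are laid out contiguously starting at `kv_smem_base`, which is consistent with `kv_total_bytes` being a multiple of `kv_stage_bytes` but is not stated explicitly here.

```python
def kv_stage_elems(bn: int, bk0: int) -> int:
    """Elements per stage per SM: (BN // 2) * BK0."""
    return (bn // 2) * bk0


def kv_stage_elem_offset(stage: int, bn: int, bk0: int) -> int:
    """Element offset of a stage relative to kv_smem_base.

    Assumes stages are stored back to back.
    """
    return stage * kv_stage_elems(bn, bk0)


# Hypothetical configuration: BN=128, BK0=64, 2-byte elements, 3 stages.
BN, BK0, ELEM_SIZE, NUM_STAGES = 128, 64, 2, 3

stage_elems = kv_stage_elems(BN, BK0)      # elements in one stage
stage_bytes = stage_elems * ELEM_SIZE      # models kv_stage_bytes
total_bytes = NUM_STAGES * stage_bytes     # models kv_total_bytes
stage_starts = [kv_stage_elem_offset(i, BN, BK0) for i in range(NUM_STAGES)]
```

Because each SM of the pair holds only BN//2 rows of K, the two SMs together cover the full BN rows of a tile.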

correction_smem​

correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the correction region (BM Float32 elements).

Returns:

UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

mbar_base​

mbar_base(self) -> MBarType

Base of the barrier region.

Layout: num_fixed_mbars fixed barriers (10 when split_o, 8 otherwise) followed by 2*num_kv_stages KV pipeline barriers. The Depth512MBars struct wraps this pointer with named accessors for each barrier role.

Returns:

MBarType
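The barrier layout described above can be modeled as a count of fixed barriers followed by two barriers per KV stage. The Python sketch below reproduces that counting. The function names mirror the comptime members but are otherwise hypothetical, and the 8-byte barrier size is an assumption about `size_of[SharedMemBarrier]()` that this page does not confirm.

```python
def num_fixed_mbars(split_o: bool) -> int:
    """10 fixed barriers when split_o is set, 8 otherwise."""
    return 10 if split_o else 8


def num_kv_mbars(num_kv_stages: int) -> int:
    """Two KV pipeline barriers per stage."""
    return 2 * num_kv_stages


def total_mbars(split_o: bool, num_kv_stages: int) -> int:
    return num_fixed_mbars(split_o) + num_kv_mbars(num_kv_stages)


def kv_mbar_index_range(split_o: bool, num_kv_stages: int) -> range:
    """Indices of the KV pipeline barriers, which follow the fixed ones."""
    start = num_fixed_mbars(split_o)
    return range(start, start + num_kv_mbars(num_kv_stages))


# Assumption: each SharedMemBarrier occupies 8 bytes.
MBAR_SIZE = 8

# Hypothetical configuration: 3 KV stages, split_o disabled.
count = total_mbars(split_o=False, num_kv_stages=3)
region_bytes = count * MBAR_SIZE
kv_indices = list(kv_mbar_index_range(split_o=False, num_kv_stages=3))
```

How the two barriers of each stage are assigned to roles is left to Depth512MBars and is not modeled here.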

tmem_addr_ptr​

tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Pointer to the single UInt32 storing the TMEM address.

Returns:

UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

smem_size​

static smem_size() -> Int

Total dynamic shared memory bytes required.

Returns:

Int