
Mojo struct

Depth512AttentionSMem

struct Depth512AttentionSMem[qkv_dtype: DType, //, config: Depth512SM100Config[qkv_dtype]]

Shared memory layout manager for depth=512 pair-CTA attention kernels.

Stores a base pointer into dynamic shared memory and provides accessor methods for each region (Q, P, KV pipeline, correction, mbarriers, TMEM address). All byte-offset arithmetic is done at compile time (comptime), so each accessor compiles down to a single pointer add plus a bitcast.
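The regions are packed back to back in the order Q, P, KV pipeline, correction, mbarriers, TMEM address: each `*_byte_offset` member equals the previous region's offset plus its size, starting from `q_byte_offset = 0`. The small Python model below illustrates that chain. It is not the Mojo implementation, and all byte sizes in it are hypothetical values chosen only to make the arithmetic concrete (including an assumed 8-byte barrier size).

```python
# Python model of the comptime layout chain (illustration only, not the Mojo code).
# Regions are packed back to back: each *_byte_offset equals the previous
# region's offset plus its size, starting from q_byte_offset = 0.


def pack_regions(regions):
    """Return {name: (byte_offset, byte_size)} for regions laid out in order."""
    layout = {}
    cursor = 0  # the Q region starts at offset 0
    for name, size in regions:
        layout[name] = (cursor, size)
        cursor += size
    return layout


# Hypothetical byte sizes, chosen only to make the arithmetic concrete.
regions = [
    ("q", 64 * 256 * 2),       # BM x (assumed depth) x 2-byte elements
    ("p", 64 * 128 * 2),       # BM x (assumed width) x 2-byte elements
    ("kv", 2 * 64 * 128 * 2),  # 2 stages x stage bytes
    ("correction", 64 * 4),    # BM Float32 values
    ("mbar", 14 * 8),          # 10 fixed + 2*2 KV barriers, 8 bytes each (assumed)
    ("tmem_addr", 4),          # one UInt32
]

layout = pack_regions(regions)
for name, (offset, size) in layout.items():
    print(f"{name:>10}: offset={offset:6d}  size={size}")
```

Because every size is known at compile time, the Mojo version can fold each offset into a constant, which is why the accessors reduce to a pointer add.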

Parameters

  • qkv_dtype (DType): Element type of Q/K/V data in shared memory.
  • config (Depth512SM100Config): Depth512SM100 configuration (tile sizes, depths, staging counts, etc.). All fields are comptime-accessible when the config is a comptime parameter.

Fields

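  • base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory; every region accessor computes its address relative to it.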

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

correction_byte_offset

comptime correction_byte_offset = (Depth512AttentionSMem[config].kv_byte_offset + Depth512AttentionSMem[config].kv_total_bytes)

correction_bytes

comptime correction_bytes = (Depth512SM100Config[qkv_dtype].BM * size_of[DType.float32]())

correction_offset

comptime correction_offset = SIMD((Depth512AttentionSMem[config].correction_byte_offset // size_of[DType.float32]()))

kv_byte_offset

comptime kv_byte_offset = (Depth512AttentionSMem[config].p_byte_offset + Depth512AttentionSMem[config].p_bytes)

kv_offset

comptime kv_offset = SIMD((Depth512AttentionSMem[config].kv_byte_offset // size_of[qkv_dtype]()))

kv_stage_bytes

comptime kv_stage_bytes = (((config // 2) * config) * size_of[qkv_dtype]())

kv_total_bytes

comptime kv_total_bytes = (config * Depth512AttentionSMem[config].kv_stage_bytes)
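Going by the `kv_smem_base` description (a stage holds (BN//2)*BK0 elements of qkv_dtype per SM) and the `mbar_base` description (which counts `num_kv_stages` stages), the two KV sizing members can be modeled as below. This is a Python sketch; the names `BN`, `BK0` and `num_kv_stages` are assumed stand-ins for the corresponding config fields, and the example values are hypothetical.

```python
# Python model of the KV pipeline sizing (illustration only).
# Assumed names: BN, BK0 and num_kv_stages stand for the config fields that the
# kv_smem_base and mbar_base descriptions mention.


def kv_sizes(BN, BK0, num_kv_stages, elem_bytes):
    # One stage holds a (BN // 2) x BK0 tile of qkv_dtype per SM.
    kv_stage_bytes = (BN // 2) * BK0 * elem_bytes
    # All stages are stored contiguously, one after another.
    kv_total_bytes = num_kv_stages * kv_stage_bytes
    return kv_stage_bytes, kv_total_bytes


# Hypothetical values: BN = 128, BK0 = 64, three stages, 2-byte elements.
stage_bytes, total_bytes = kv_sizes(BN=128, BK0=64, num_kv_stages=3, elem_bytes=2)
print(stage_bytes, total_bytes)
```

Adding a stage grows the KV region linearly by one `kv_stage_bytes`, and every region after it (correction, mbarriers, TMEM address) shifts by the same amount.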

mbar_byte_offset

comptime mbar_byte_offset = (Depth512AttentionSMem[config].correction_byte_offset + Depth512AttentionSMem[config].correction_bytes)

mbar_bytes

comptime mbar_bytes = (Depth512AttentionSMem[config].total_mbars * size_of[SharedMemBarrier]())

mbar_offset

comptime mbar_offset = SIMD((Depth512AttentionSMem[config].mbar_byte_offset // size_of[SharedMemBarrier]()))

num_fixed_mbars

comptime num_fixed_mbars = 10

num_kv_mbars

comptime num_kv_mbars = (2 * config)

p_byte_offset

comptime p_byte_offset = Depth512AttentionSMem[config].q_bytes

p_bytes

comptime p_bytes = ((Depth512SM100Config[qkv_dtype].BM * config) * size_of[qkv_dtype]())

p_offset

comptime p_offset = SIMD((Depth512AttentionSMem[config].p_byte_offset // size_of[qkv_dtype]()))

q_byte_offset

comptime q_byte_offset = 0

q_bytes

comptime q_bytes = ((Depth512SM100Config[qkv_dtype].BM * config) * size_of[qkv_dtype]())

q_offset

comptime q_offset = SIMD(0)

tmem_addr_byte_offset

comptime tmem_addr_byte_offset = (Depth512AttentionSMem[config].mbar_byte_offset + Depth512AttentionSMem[config].mbar_bytes)

total_mbars

comptime total_mbars = (Depth512AttentionSMem[config].num_fixed_mbars + Depth512AttentionSMem[config].num_kv_mbars)
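The `*_offset` members convert a byte offset into an element offset by integer division by the element size, and `total_mbars` counts the 10 fixed barriers plus two per KV stage. The Python sketch below models both. The 8-byte barrier size is an assumption (an mbarrier is a 64-bit shared-memory object), and the stage count and byte offset are hypothetical; the alignment check makes explicit that the division is only exact when a region starts on a multiple of its element size.

```python
# Python model of the barrier count and the byte-to-element offset conversion
# (illustration only). SHARED_MEM_BARRIER_BYTES = 8 is an assumption: an
# mbarrier is a 64-bit shared-memory object.

SHARED_MEM_BARRIER_BYTES = 8


def element_offset(byte_offset, elem_bytes):
    # The *_offset members divide a byte offset by the element size; this is
    # exact only when the region start is aligned to that element size.
    if byte_offset % elem_bytes != 0:
        raise ValueError("region start is not aligned to the element size")
    return byte_offset // elem_bytes


num_fixed_mbars = 10
num_kv_stages = 3                  # hypothetical stage count
num_kv_mbars = 2 * num_kv_stages   # two barriers per KV stage
total_mbars = num_fixed_mbars + num_kv_mbars
mbar_bytes = total_mbars * SHARED_MEM_BARRIER_BYTES

mbar_byte_offset = 82176           # hypothetical, 8-byte aligned
mbar_offset = element_offset(mbar_byte_offset, SHARED_MEM_BARRIER_BYTES)
tmem_addr_byte_offset = mbar_byte_offset + mbar_bytes
print(total_mbars, mbar_bytes, mbar_offset, tmem_addr_byte_offset)
```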

Methods

__init__

__init__() -> Self

Obtain the base pointer from the kernel's dynamic shared memory.

q_smem

q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q region (offset 0).

Returns:

UnsafePointer

o_smem

o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

Returns the same physical memory as the Q region, bitcast to the output element type, so the output tile reuses Q's storage.
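A bitcast reinterprets the same bytes rather than converting values. Python's `memoryview.cast` gives a rough analogy: two typed views over one buffer, where a write through one view is visible through the other. This is only an illustration; in the Mojo struct the aliasing is a pointer bitcast, and the point in the kernel at which O starts overwriting Q (presumably once Q is no longer read) is an assumption not stated on this page.

```python
import struct

# Python analogy for o_smem (illustration only): two typed views over one
# buffer, the way o_smem bitcasts the Q region to the output element type.
smem = bytearray(16)                  # stands in for the Q region

q_view = memoryview(smem).cast("H")   # 2-byte elements, like a 16-bit qkv_dtype
o_view = memoryview(smem).cast("f")   # 4-byte elements, like a Float32 output

o_view[0] = 1.0                       # write through the output view

# The write is visible through the Q view because both alias the same bytes.
assert bytes(smem[:4]) == struct.pack("=f", 1.0)
print(len(q_view), len(o_view))       # element counts differ; the bytes are shared
```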

Returns:

UnsafePointer

p_smem

p_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the P buffer region, which holds P as a shared-memory operand for the SS MMA that computes P@V.

Returns:

UnsafePointer

kv_smem_base

kv_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the KV pipeline region (stage 0).

Each stage holds (BN//2)*BK0 elements of qkv_dtype per SM. A stage is interpreted as K (BN//2 × BK0) during Q@K' or as V (BK1 × ov_depth//2) during P@V. The pair-CTA MMA reads the corresponding stage from both SMs of the pair.
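Since each stage is a fixed number of elements, the element offset of stage s is the KV base plus s times the stage size, and the same buffer can be viewed as either a K or a V tile. The Python sketch below models that addressing. It assumes the V view must fit within the stage (BK1 × ov_depth//2 ≤ (BN//2)*BK0), ignores any swizzling the real tiles may use, and uses hypothetical tile sizes and a hypothetical `kv_offset`.

```python
# Python model of addressing the KV stages and reusing one stage as K or V
# (illustration only; ignores any swizzling the real tiles may use).


def stage_elems(BN, BK0):
    # Elements of qkv_dtype per stage on one SM: (BN // 2) * BK0.
    return (BN // 2) * BK0


def stage_base(kv_offset, stage, BN, BK0):
    # Element offset of stage `stage`, counting from kv_offset (stage 0).
    return kv_offset + stage * stage_elems(BN, BK0)


def v_tile_fits(BN, BK0, BK1, ov_depth):
    # Assumption: the V view (BK1 x ov_depth // 2) must not exceed the stage size.
    return BK1 * (ov_depth // 2) <= stage_elems(BN, BK0)


BN, BK0, BK1, ov_depth = 128, 64, 16, 512   # hypothetical tile sizes
kv_offset = 24576                           # hypothetical element offset of stage 0

print([stage_base(kv_offset, s, BN, BK0) for s in range(3)])
print(v_tile_fits(BN, BK0, BK1, ov_depth))
```

With these example sizes the K view and the V view both use exactly 4096 elements, so the stage is fully occupied in either role.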

Returns:

UnsafePointer

correction_smem

correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the correction region (BM Float32 elements).

Returns:

UnsafePointer

mbar_base

mbar_base(self) -> MBarType

Base of the barrier region.

Layout: 10 fixed barriers followed by 2*num_kv_stages KV pipeline barriers. The WS4 barrier struct wraps this pointer with named accessors for each barrier role.
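With 10 fixed barriers first, the KV pipeline barriers occupy indices 10 through 10 + 2*num_kv_stages - 1. The Python sketch below maps a (stage, kind) pair to a barrier index. The interleaved ordering (two consecutive barriers per stage) is hypothetical; the actual role assignment is defined by the WS4 barrier struct, which this page does not describe in detail.

```python
# Python sketch of indexing into the barrier region (illustration only).
# The interleaved (stage, kind) ordering below is hypothetical; the real role
# assignment is defined by the WS4 barrier struct.

NUM_FIXED_MBARS = 10


def kv_barrier_index(stage, kind, num_kv_stages):
    """Index of a KV pipeline barrier; kind is 0 or 1 (two barriers per stage)."""
    if not (0 <= stage < num_kv_stages) or kind not in (0, 1):
        raise IndexError("stage or kind out of range")
    return NUM_FIXED_MBARS + 2 * stage + kind


num_kv_stages = 3
indices = [kv_barrier_index(s, k, num_kv_stages)
           for s in range(num_kv_stages) for k in (0, 1)]
print(indices)
```

Whatever the real ordering, the KV barriers fill exactly the index range after the fixed ones, which matches `total_mbars = num_fixed_mbars + num_kv_mbars`.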

Returns:

MBarType

tmem_addr_ptr

tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Pointer to the single UInt32 storing the TMEM address.

Returns:

UnsafePointer

smem_size

static smem_size() -> Int

Total dynamic shared memory bytes required.

Returns:

Int
