
Mojo struct

Depth512AttentionSMem

struct Depth512AttentionSMem[qkv_dtype: DType, //, config: Depth512SM100Config[qkv_dtype]]

Shared memory layout manager for depth=512 pair-CTA attention kernels.

Stores a base pointer into dynamic shared memory and provides accessor methods for each region (Q, P, KV pipeline, correction, mbarriers, TMEM address). All byte-offset arithmetic is evaluated at compile time (comptime), so each accessor compiles down to a single pointer add plus a bitcast.
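The region order can be modelled outside the kernel as a running sum of region sizes. The following is a minimal Python sketch of that arithmetic, not the Mojo implementation. The `Config` values are hypothetical and chosen only for illustration. The 2-byte element size (for example bfloat16) and the 8-byte barrier size are assumptions. The per-region sizes follow the descriptions on this page: `(BN//2)*BK0` elements per KV stage, `BM` Float32 correction values, and one UInt32 for the TMEM address.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Config:
    # Hypothetical tile sizes; the real Depth512SM100Config chooses these.
    BM: int = 64
    BN: int = 128
    BK0: int = 64
    qk_depth: int = 512
    num_kv_stages: int = 2
    split_o: bool = False


def region_offsets(cfg, elem_bytes=2, f32_bytes=4, mbar_bytes_each=8):
    """Return {region: (byte_offset, byte_size)} in layout order."""
    num_mbars = (10 if cfg.split_o else 8) + 2 * cfg.num_kv_stages
    sizes = [
        ("q", cfg.BM * cfg.qk_depth * elem_bytes),
        ("p", cfg.BM * cfg.BN * elem_bytes),
        ("kv", cfg.num_kv_stages * (cfg.BN // 2) * cfg.BK0 * elem_bytes),
        ("correction", cfg.BM * f32_bytes),
        ("mbar", num_mbars * mbar_bytes_each),  # assumed 8 bytes per barrier
        ("tmem_addr", 4),                       # a single UInt32
    ]
    out, cursor = {}, 0
    for name, size in sizes:
        out[name] = (cursor, size)  # each region starts where the last ended
        cursor += size
    return out


layout = region_offsets(Config())
for name, (offset, size) in layout.items():
    print(f"{name:>10}: offset={offset:6d} size={size:6d}")
```

Because every size depends only on the configuration, the whole table is known before the kernel runs, which is what lets the Mojo accessors reduce to a constant pointer add.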

Parameters​

  • ​qkv_dtype (DType): Element type of Q/K/V data in shared memory.
  • ​config (Depth512SM100Config[qkv_dtype]): Depth512SM100 configuration (tile sizes, depths, staging counts, etc.). All fields are comptime-accessible when the config is a comptime parameter.

Fields​

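  • base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory; every region accessor is computed relative to it.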

Implemented traits​

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members​

correction_byte_offset​

comptime correction_byte_offset = (Int((add (mul size_of[qkv_dtype](), config.BM, config.BN), (mul size_of[qkv_dtype](), config.BM, config.qk_depth))) + Int((mul (config.BN // Int(2)), size_of[qkv_dtype](), config.BK0, config.num_kv_stages)))

correction_bytes​

comptime correction_bytes = (config.BM * size_of[DType.float32]())

correction_offset​

comptime correction_offset = SIMD((Int((add (mul (config.BN // Int(2)), size_of[qkv_dtype](), config.BK0, config.num_kv_stages), (mul size_of[qkv_dtype](), config.BM, config.BN), (mul size_of[qkv_dtype](), config.BM, config.qk_depth))) // size_of[DType.float32]()))

kv_byte_offset​

comptime kv_byte_offset = (Int((mul size_of[qkv_dtype](), config.BM, config.qk_depth)) + Int((mul size_of[qkv_dtype](), config.BM, config.BN)))

kv_offset​

comptime kv_offset = SIMD((Int((add (mul size_of[qkv_dtype](), config.BM, config.BN), (mul size_of[qkv_dtype](), config.BM, config.qk_depth))) // size_of[qkv_dtype]()))

kv_stage_bytes​

comptime kv_stage_bytes = (Int((mul (config.BN // Int(2)), config.BK0)) * size_of[qkv_dtype]())

kv_total_bytes​

comptime kv_total_bytes = (config.num_kv_stages * Int((mul (config.BN // Int(2)), size_of[qkv_dtype](), config.BK0)))

mbar_byte_offset​

comptime mbar_byte_offset = (Int((add (mul (config.BN // Int(2)), size_of[qkv_dtype](), config.BK0, config.num_kv_stages), (mul size_of[qkv_dtype](), config.BM, config.BN), (mul size_of[qkv_dtype](), config.BM, config.qk_depth))) + Int((mul size_of[DType.float32](), config.BM)))

mbar_bytes​

comptime mbar_bytes = (Int((add (mul config.num_kv_stages, 2), Int(10) if config.split_o else Int(8))) * size_of[SharedMemBarrier]())

mbar_offset​

comptime mbar_offset = SIMD((Int((add (mul (config.BN // Int(2)), size_of[qkv_dtype](), config.BK0, config.num_kv_stages), (mul size_of[qkv_dtype](), config.BM, config.BN), (mul size_of[qkv_dtype](), config.BM, config.qk_depth), (mul size_of[DType.float32](), config.BM))) // size_of[SharedMemBarrier]()))

num_fixed_mbars​

comptime num_fixed_mbars = Int(10) if config.split_o else Int(8)

num_kv_mbars​

comptime num_kv_mbars = (Int(2) * config.num_kv_stages)

p_byte_offset​

comptime p_byte_offset = Depth512AttentionSMem[config].q_bytes

p_bytes​

comptime p_bytes = (Int((mul config.BM, config.BN)) * size_of[qkv_dtype]())

p_offset​

comptime p_offset = SIMD((Int((mul size_of[qkv_dtype](), config.BM, config.qk_depth)) // size_of[qkv_dtype]()))

q_byte_offset​

comptime q_byte_offset = Int(0)

q_bytes​

comptime q_bytes = (Int((mul config.BM, config.qk_depth)) * size_of[qkv_dtype]())

q_offset​

comptime q_offset = Int32(0)

tmem_addr_byte_offset​

comptime tmem_addr_byte_offset = (Int((add (mul (config.BN // Int(2)), size_of[qkv_dtype](), config.BK0, config.num_kv_stages), (mul size_of[qkv_dtype](), config.BM, config.BN), (mul size_of[qkv_dtype](), config.BM, config.qk_depth), (mul size_of[DType.float32](), config.BM))) + Int((add (mul size_of[SharedMemBarrier](), config.num_kv_stages, 2), (mul size_of[SharedMemBarrier](), Int(10) if config.split_o else Int(8)))))

total_mbars​

comptime total_mbars = ((Int(10) if config.split_o else Int(8)) + Int((mul config.num_kv_stages, 2)))
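The offset members above come in pairs: a `*_byte_offset` expressed in bytes, and a `*_offset` that divides the same sum by the size of the region's element type so it can index a typed pointer. A minimal Python sketch of that conversion follows. The tile sizes and the 2-byte `qkv_dtype` are hypothetical, and the alignment check is an addition made for illustration; the expressions on this page use plain floor division.

```python
def element_offset(byte_offset: int, elem_bytes: int) -> int:
    """Convert a byte offset into an index in units of elem_bytes."""
    if byte_offset % elem_bytes != 0:
        raise ValueError("region base is not aligned to its element size")
    return byte_offset // elem_bytes


# Hypothetical sizes, for illustration only.
BM, BN, BK0, QK_DEPTH, NUM_KV_STAGES = 64, 128, 64, 512, 2
ELEM = 2  # bytes per qkv_dtype element (assumed 16-bit type)
F32 = 4   # bytes per Float32

q_bytes = BM * QK_DEPTH * ELEM
p_bytes = BM * BN * ELEM
kv_total_bytes = NUM_KV_STAGES * (BN // 2) * BK0 * ELEM

# Byte offsets accumulate in layout order: Q, then P, then KV, then correction.
p_byte_offset = q_bytes
kv_byte_offset = q_bytes + p_bytes
correction_byte_offset = kv_byte_offset + kv_total_bytes

# Element offsets are measured in the pointee type of each accessor.
p_offset = element_offset(p_byte_offset, ELEM)
kv_offset = element_offset(kv_byte_offset, ELEM)
correction_offset = element_offset(correction_byte_offset, F32)
print(p_offset, kv_offset, correction_offset)
```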

Methods​

__init__​

def __init__() -> Self

Obtain the base pointer from the kernel's dynamic shared memory.

q_smem​

def q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q region (offset 0).

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

o_smem​

def o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

Same physical memory as Q, bitcast to the output element type.

Returns:

UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

p_smem​

def p_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the P buffer region for SS MMA P@V.

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

kv_smem_base​

def kv_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the KV pipeline region (stage 0).

Each stage holds (BN//2)*BK0 elements of qkv_dtype per SM. Interpreted as K (BN//2 × BK0) during Q@K' or V (BK1 × ov_depth//2) during P@V. The pair-CTA MMA reads from both SMs.

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
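The per-stage sizing described for `kv_smem_base` can be checked with a small model. This Python sketch assumes that stages are packed back to back starting at stage 0 (suggested by `kv_total_bytes`, but not stated explicitly) and that the V view must fit inside one stage. The helper names and all tile values, including `BK1` and `ov_depth`, are hypothetical.

```python
def kv_stage_elems(bn: int, bk0: int) -> int:
    """Elements of qkv_dtype held by one pipeline stage on one SM."""
    return (bn // 2) * bk0


def kv_stage_base(kv_offset: int, stage: int, bn: int, bk0: int,
                  num_stages: int) -> int:
    """Element offset of a stage, assuming contiguous stages."""
    if not 0 <= stage < num_stages:
        raise IndexError(f"stage {stage} outside 0..{num_stages - 1}")
    return kv_offset + stage * kv_stage_elems(bn, bk0)


# Hypothetical values, for illustration only.
BN, BK0, BK1, OV_DEPTH, NUM_KV_STAGES = 128, 64, 16, 512, 2
KV_OFFSET = 40960  # element offset of stage 0 in this example

k_view = (BN // 2, BK0)       # shape used during Q@K'
v_view = (BK1, OV_DEPTH // 2)  # shape used during P@V
stage_elems = kv_stage_elems(BN, BK0)

# Both interpretations must fit in the same stage buffer.
assert k_view[0] * k_view[1] == stage_elems
assert v_view[0] * v_view[1] <= stage_elems

stage1 = kv_stage_base(KV_OFFSET, 1, BN, BK0, NUM_KV_STAGES)
print(stage_elems, stage1)
```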

correction_smem​

def correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the correction region (BM Float32 elements).

Returns:

UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

mbar_base​

def mbar_base(self) -> MBarType

Base of the barrier region.

Layout: num_fixed_mbars fixed barriers (10 when split_o, 8 otherwise) followed by 2*num_kv_stages KV pipeline barriers. The Depth512MBars struct wraps this pointer with named accessors for each barrier role.

Returns:

MBarType
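The barrier layout described for `mbar_base` amounts to two consecutive index ranges. A minimal Python sketch of the counts follows, with a hypothetical `mbar_layout` helper. It does not model which role each barrier plays, since that is the job of `Depth512MBars`.

```python
def mbar_layout(num_kv_stages: int, split_o: bool) -> dict:
    """Index ranges of the fixed and KV pipeline barriers."""
    num_fixed = 10 if split_o else 8  # num_fixed_mbars
    num_kv = 2 * num_kv_stages        # num_kv_mbars
    return {
        "fixed": range(0, num_fixed),
        "kv": range(num_fixed, num_fixed + num_kv),
        "total": num_fixed + num_kv,  # total_mbars
    }


print(mbar_layout(num_kv_stages=2, split_o=False))
print(mbar_layout(num_kv_stages=3, split_o=True))
```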

tmem_addr_ptr​

def tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Pointer to the single UInt32 storing the TMEM address.

Returns:

UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

smem_size​

static def smem_size() -> Int

Total dynamic shared memory bytes required.

Returns:

Int
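One use of the total is to check how many KV stages fit in the shared memory available to a block. The Python sketch below assumes the total is the sum of the six regions with no extra padding, an 8-byte barrier, a 2-byte `qkv_dtype`, and a 227 KiB per-block limit. All of these are assumptions, as are the tile values and the helper names, so check them against the actual target before relying on the numbers.

```python
def total_bytes(bm, bn, bk0, qk_depth, num_kv_stages, split_o,
                elem_bytes=2, mbar_size=8):
    """Sum of all regions, assuming no padding between them."""
    q = bm * qk_depth * elem_bytes
    p = bm * bn * elem_bytes
    kv = num_kv_stages * (bn // 2) * bk0 * elem_bytes
    correction = bm * 4  # BM Float32 values
    mbars = ((10 if split_o else 8) + 2 * num_kv_stages) * mbar_size
    tmem_addr = 4        # a single UInt32
    return q + p + kv + correction + mbars + tmem_addr


def max_kv_stages(limit, **tile):
    """Largest stage count whose total still fits within limit bytes."""
    stages = 0
    while total_bytes(num_kv_stages=stages + 1, **tile) <= limit:
        stages += 1
    return stages


# Hypothetical tile configuration and assumed per-block limit.
TILE = dict(bm=64, bn=128, bk0=64, qk_depth=512, split_o=False)
SMEM_LIMIT = 227 * 1024

print(total_bytes(num_kv_stages=2, **TILE))
print(max_kv_stages(SMEM_LIMIT, **TILE))
```

Each extra stage costs one stage buffer plus two barriers, so the stage count is the main knob for trading shared memory against pipeline depth.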