IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

SM100AttentionSMem

struct SM100AttentionSMem[qkv_dtype: DType, rope_dtype: DType, scale_dtype: DType, //, config: FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype], *, use_order_barriers: Bool = EnableForcedOrdering]

Shared memory layout manager for SM100 Flash Attention kernels.

Stores a base pointer into dynamic shared memory and provides an accessor method for each region (Q, Q rope, K, V, rope, correction, Q/K scales, mbarriers, and the TMEM address). All byte-offset arithmetic is evaluated at compile time (comptime), so each accessor compiles down to a single pointer add plus a bitcast.
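
As a rough host-side illustration of the pointer-add-then-bitcast pattern, the Python sketch below treats a bytearray as dynamic shared memory and builds each typed view by slicing at a fixed byte offset and casting. The offsets and the typed_view helper are hypothetical; the real accessors operate on an UnsafePointer in AddressSpace.SHARED, not on Python objects.

```python
import struct

# Hypothetical stand-in for the kernel's dynamic shared memory.
smem = bytearray(64)

# Hypothetical region offsets; in the struct these are comptime constants.
Q_BYTE_OFFSET = 0
CORRECTION_BYTE_OFFSET = 32


def typed_view(buf: bytearray, byte_offset: int, fmt: str) -> memoryview:
    """Model of one accessor: advance by byte_offset, then reinterpret as fmt."""
    return memoryview(buf)[byte_offset:].cast(fmt)


q = typed_view(smem, Q_BYTE_OFFSET, "H")                    # 16-bit Q elements
correction = typed_view(smem, CORRECTION_BYTE_OFFSET, "f")  # Float32 elements

correction[0] = 1.5
# The write lands at the region's byte offset in the shared buffer.
assert struct.unpack_from("=f", smem, CORRECTION_BYTE_OFFSET)[0] == 1.5
```

Like a raw pointer, each view is unbounded on the high side; keeping regions disjoint is the job of the offset arithmetic, not of the view.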

Parameters

  • qkv_dtype (DType): Element type of Q/K/V data in shared memory.
  • rope_dtype (DType): Element type of the rope portions of Q and K.
  • scale_dtype (DType): Element type of the per-token scale used for Q and K.
  • config (FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype]): FA4 configuration (tile sizes, depths, staging counts, etc.).
  • use_order_barriers (Bool): Whether forced-ordering barriers are allocated. Defaults to EnableForcedOrdering.

Fields

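  • base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Byte pointer to the start of the kernel's dynamic shared memory; every region accessor offsets from it.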

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

correction_byte_offset

comptime correction_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_bytes)

correction_bytes

comptime correction_bytes = ((2 if (config == 1) else 1 * config) * size_of[DType.float32]())

correction_offset

comptime correction_offset = SIMD((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].correction_byte_offset // size_of[DType.float32]()))
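
The *_offset members (correction_offset here, and likewise kv_offset and mbar_offset) turn a byte offset into an index measured in elements of the region's type. A minimal Python sketch of that conversion follows; the explicit alignment check is an added assumption about what a caller should verify, not something the struct is documented to do. The 8-byte barrier size reflects the 64-bit PTX mbarrier object and is assumed to match size_of[SharedMemBarrier](); the byte offsets are hypothetical.

```python
FLOAT32_SIZE = 4
MBARRIER_SIZE = 8  # assumed: 64-bit PTX mbarrier object, 8-byte aligned


def element_offset(byte_offset: int, elem_size: int) -> int:
    """Convert a byte offset into an index in elem_size-byte units.

    Mirrors the `byte_offset // size_of[...]()` pattern; the division is only
    meaningful when the region starts on an elem_size boundary.
    """
    if byte_offset % elem_size != 0:
        raise ValueError(f"offset {byte_offset} is not {elem_size}-byte aligned")
    return byte_offset // elem_size


# Hypothetical byte offsets, for illustration only.
print(element_offset(98_304, FLOAT32_SIZE))   # 24576
print(element_offset(99_840, MBARRIER_SIZE))  # 12480
```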

k_scale_byte_offset

comptime k_scale_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_scale_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_scale_bytes)

k_scale_bytes

comptime k_scale_bytes = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].num_k_scale_bufs * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_scale_stride_bytes)

k_scale_stride_bytes

comptime k_scale_stride_bytes = (config * size_of[scale_dtype]() if (scale_dtype != DType.invalid) else 0)
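
k_scale_stride_bytes (like q_scale_bytes and rope_dt_size) collapses to zero when its dtype is DType.invalid, so the optional region takes no space and the offsets that follow it do not move. The sketch below models that rule in Python. None stands in for DType.invalid, the dtype-size table is hypothetical, and the multiplier elided from the rendered expression is assumed to be BN, following the k_scale_smem description (num_k_scale_bufs * BN elements).

```python
from typing import Optional

# Hypothetical dtype-size table; None plays the role of DType.invalid.
DTYPE_SIZES = {"float32": 4, "bfloat16": 2}


def scale_region_bytes(num_tokens: int, scale_dtype: Optional[str]) -> int:
    """Bytes for one per-token scale buffer; zero when scaling is disabled."""
    if scale_dtype is None:  # DType.invalid: no per-token scales
        return 0
    return num_tokens * DTYPE_SIZES[scale_dtype]


BN = 128              # hypothetical K tile rows (tokens per K buffer)
num_k_scale_bufs = 2  # hypothetical

k_scale_stride_bytes = scale_region_bytes(BN, "float32")
k_scale_bytes = num_k_scale_bufs * k_scale_stride_bytes
print(k_scale_stride_bytes, k_scale_bytes)  # 512 1024
print(scale_region_bytes(BN, None))         # 0
```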

k_stage_bytes

comptime k_stage_bytes = ((config.v_cols_per_cta() * config) * size_of[qkv_dtype]()) if config.use_fused_kv else (((config.v_cols_per_cta() * config) * size_of[qkv_dtype]()) + ((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_depth * config.k_rows_per_cta()) * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_dt_size))
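
The two branches of k_stage_bytes differ only in whether a K stage also carries its rope slice. A Python sketch of the same formula follows; the parameter names are hypothetical stand-ins for the config queries, and rows stands for the factor that is elided from the rendered expression.

```python
def k_stage_bytes(nope_cols: int, rows: int, qkv_size: int,
                  rope_depth: int, k_rows: int, rope_size: int,
                  use_fused_kv: bool) -> int:
    """Bytes in one K stage.

    nope_cols ~ config.v_cols_per_cta(); rows ~ the elided factor (hypothetical);
    k_rows ~ config.k_rows_per_cta(); rope_size ~ rope_dt_size, which is 0 when
    rope_dtype is DType.invalid.
    """
    nope_bytes = nope_cols * rows * qkv_size
    if use_fused_kv:
        # Fused: the stage holds only K_nope; K rope lives in its own region.
        return nope_bytes
    # Split: each stage also carries its rope slice.
    return nope_bytes + rope_depth * k_rows * rope_size


# Hypothetical sizes: 512 nope columns, 64 rows, 1-byte elements, 64-wide 2-byte rope.
print(k_stage_bytes(512, 64, 1, 64, 64, 2, use_fused_kv=True))   # 32768
print(k_stage_bytes(512, 64, 1, 64, 64, 2, use_fused_kv=False))  # 40960
```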

k_total_bytes

comptime k_total_bytes = (config * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_stage_bytes)

kv_byte_offset

comptime kv_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_bytes

kv_bytes

comptime kv_bytes = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_total_bytes + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_bytes) if config.use_fused_kv else SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_stages_bytes

kv_offset

comptime kv_offset = SIMD((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset // size_of[qkv_dtype]()))

kv_stages_bytes

comptime kv_stages_bytes = (config * (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_stage_bytes + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].v_stage_bytes))
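
kv_bytes selects between two KV layouts. The Python sketch below computes both. It assumes the factor elided from k_total_bytes and kv_stages_bytes is the number of KV stages, as the v_smem_base description (num_kv_stages * ...) suggests; all sizes are hypothetical.

```python
def kv_bytes(num_kv_stages: int, k_stage_bytes: int, v_stage_bytes: int,
             rope_bytes: int, use_fused_kv: bool) -> int:
    if use_fused_kv:
        # Fused: K_nope and V share each stage buffer; the rope buffers follow
        # all K stages (k_total_bytes + rope_bytes).
        return num_kv_stages * k_stage_bytes + rope_bytes
    # Split: all K stages, then all V stages (kv_stages_bytes).
    return num_kv_stages * (k_stage_bytes + v_stage_bytes)


# Hypothetical sizes, for illustration only.
print(kv_bytes(2, 32_768, 32_768, 16_384, use_fused_kv=True))  # 81920
print(kv_bytes(2, 40_960, 32_768, 0, use_fused_kv=False))      # 147456
```

In this model the fused layout saves one V buffer per stage at the cost of a separate rope region.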

mbar_byte_offset

comptime mbar_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_scale_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_scale_bytes)

mbar_bytes

comptime mbar_bytes = (Int[UInt32](FA4MiscMBars.num_mbars()) * size_of[SharedMemBarrier]())

mbar_offset

comptime mbar_offset = SIMD((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].mbar_byte_offset // size_of[SharedMemBarrier]()))

MiscMBarsType

comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=config.use_fused_kv, pair_cta=config.pair_cta, num_qo=config.num_qo]

num_k_scale_bufs

comptime num_k_scale_bufs = config.num_k_scale_bufs()

num_rope_bufs

comptime num_rope_bufs = config.num_rope_buffers()

q_byte_offset

comptime q_byte_offset = 0

q_bytes

comptime q_bytes = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_nope_bytes + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_rope_bytes)

q_nope_bytes

comptime q_nope_bytes = ((config * config) * size_of[qkv_dtype]())

q_offset

comptime q_offset = Int32(0)

q_rope_byte_offset

comptime q_rope_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_nope_bytes

q_rope_bytes

comptime q_rope_bytes = ((config * config.rope_depth()) * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_dt_size)

q_scale_byte_offset

comptime q_scale_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].correction_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].correction_bytes)

q_scale_bytes

comptime q_scale_bytes = (config * size_of[scale_dtype]() if (scale_dtype != DType.invalid) else 0)

rope_byte_offset

comptime rope_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_total_bytes)

rope_bytes

comptime rope_bytes = ((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].num_rope_bufs * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_stage_elems) * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_dt_size)

rope_depth

comptime rope_depth = config.rope_depth()

rope_dt_size

comptime rope_dt_size = size_of[rope_dtype]() if (rope_dtype != DType.invalid) else 0

rope_stage_elems

comptime rope_stage_elems = (config * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_depth)

tmem_addr_byte_offset

comptime tmem_addr_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].mbar_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].mbar_bytes)
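
Read together, the *_byte_offset members chain the regions in a fixed order: Q, KV, correction, q_scale, k_scale, mbarriers, and finally the TMEM address slot, each starting where the previous one ends. The Python sketch below reproduces that exclusive prefix sum; every region size in it is a hypothetical placeholder.

```python
from itertools import accumulate

# Region order follows the *_byte_offset definitions; sizes are placeholders.
REGIONS = [
    ("q", 40_960),        # q_bytes = q_nope_bytes + q_rope_bytes
    ("kv", 122_880),      # kv_bytes
    ("correction", 512),  # correction_bytes
    ("q_scale", 0),       # q_scale_bytes (0 when scale_dtype is DType.invalid)
    ("k_scale", 0),       # k_scale_bytes
    ("mbar", 24 * 8),     # mbar_bytes = num_mbars * barrier size (both assumed)
]


def byte_offsets(regions):
    """Each region starts where the previous one ends (an exclusive prefix sum)."""
    ends = list(accumulate(size for _, size in regions))
    starts = [0] + ends[:-1]
    offsets = {name: start for (name, _), start in zip(regions, starts)}
    offsets["tmem_addr"] = ends[-1]  # the UInt32 slot follows the mbarriers
    return offsets


for name, offset in byte_offsets(REGIONS).items():
    print(f"{name:>10}: {offset}")
```

Zero-sized regions (here the two scale buffers) share a start offset with the region after them, which is why disabling scaling costs nothing.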

v_byte_offset

comptime v_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset if config.use_fused_kv else (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_total_bytes)
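
The conditional in v_byte_offset is what makes v_smem_base() alias k_smem_base() in fused mode. A tiny Python model with hypothetical sizes:

```python
def v_byte_offset(kv_byte_offset: int, k_total_bytes: int, use_fused_kv: bool) -> int:
    # Fused: V aliases the K_nope stage buffers, so both regions start together.
    # Split: V stage 0 begins right after the last K stage.
    return kv_byte_offset if use_fused_kv else kv_byte_offset + k_total_bytes


KV_BYTE_OFFSET = 40_960     # hypothetical q_bytes
K_TOTAL_BYTES = 2 * 40_960  # hypothetical: 2 K stages of 40,960 bytes

print(v_byte_offset(KV_BYTE_OFFSET, K_TOTAL_BYTES, use_fused_kv=True))   # 40960
print(v_byte_offset(KV_BYTE_OFFSET, K_TOTAL_BYTES, use_fused_kv=False))  # 122880
```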

v_stage_bytes

comptime v_stage_bytes = ((config.v_cols_per_cta() * config) * size_of[qkv_dtype]())

Methods

__init__

def __init__() -> Self

Obtain the base pointer from the kernel's dynamic shared memory.

misc_mbars

def misc_mbars(self) -> Self.MiscMBarsType

Return the FA4MiscMBars wrapper over the mbarrier region.

Returns:

Self.MiscMBarsType

q_smem

def q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q region (offset 0).

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

q_rope_smem

def q_rope_smem(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q rope region, which immediately follows the Q nope data (at q_rope_byte_offset = q_nope_bytes).

Returns:

UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

o_smem

def o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

Returns the Q region bitcast to the output element type; the output tile reuses the same physical shared memory as Q.

Returns:

UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
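
Because o_smem() returns the Q region reinterpreted as output_type, writes to the output overwrite Q data, which implies the output can only be staged there once Q is no longer needed. The Python sketch below shows the aliasing with two memoryview casts over one buffer; the region size, the 16-bit/32-bit element choice, and the output_fits helper are all hypothetical.

```python
q_region = bytearray(16)  # hypothetical 16-byte Q region

# Q elements viewed as raw 16-bit patterns (e.g. bf16 bits).
q_view = memoryview(q_region).cast("H")
# The same bytes reinterpreted as 32-bit floats, standing in for output_type.
o_view = memoryview(q_region).cast("f")

q_view[0] = 0x4000  # bf16 bit pattern for 2.0
o_view[0] = 1.0     # writing O clobbers the bytes behind q_view[0] and q_view[1]
assert q_view[0] != 0x4000


def output_fits(rows: int, cols: int, out_size: int, q_bytes: int) -> bool:
    """Hypothetical check: an output tile reusing Q's memory must not exceed it."""
    return rows * cols * out_size <= q_bytes
```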

k_smem_base

def k_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the K region (first stage, offset = kv_byte_offset).

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

v_smem_base

def v_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the V region (stage 0).

Split mode: V stage 0 starts after all K stages, at kv_byte_offset + num_kv_stages * padded_qk_depth * BN * sizeof (that is, v_byte_offset = kv_byte_offset + k_total_bytes).

Fused mode: returns the same pointer as k_smem_base(), since K_nope and V share the same buffer.

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

rope_smem_base

def rope_smem_base(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the rope region, located after all K stages at rope_byte_offset (fused mode only).

Returns:

UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

correction_smem

def correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the correction region (BM Float32 elements).

Returns:

UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

q_scale_smem

def q_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the q_scale region (BM elements).

Returns:

UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

k_scale_smem

def k_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the k_scale region (num_k_scale_bufs * BN elements).

Returns:

UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

tmem_addr_ptr

def tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Pointer to the single UInt32 that stores the tensor memory (TMEM) address.

Returns:

UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

smem_size

static def smem_size() -> Int

Total dynamic shared memory bytes required.

Returns:

Int
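
smem_size() is the figure a launcher would request as dynamic shared memory. The sketch below shows one plausible derivation, assuming the layout ends with the 4-byte TMEM-address slot and carries no trailing padding (neither is stated above), plus a check against the 227 KB per-block opt-in limit that the CUDA C++ Programming Guide lists for compute capability 10.0. Confirm both assumptions against the actual implementation and target GPU.

```python
UINT32_SIZE = 4
# Assumed per-block opt-in dynamic shared memory limit for compute capability 10.0.
MAX_DYNAMIC_SMEM_PER_BLOCK = 227 * 1024  # 232,448 bytes


def smem_size(tmem_addr_byte_offset: int) -> int:
    # Assumption: the single UInt32 TMEM-address slot is the last region.
    return tmem_addr_byte_offset + UINT32_SIZE


def check_fits(total_bytes: int) -> None:
    if total_bytes > MAX_DYNAMIC_SMEM_PER_BLOCK:
        raise ValueError(
            f"layout needs {total_bytes} bytes; limit is {MAX_DYNAMIC_SMEM_PER_BLOCK}"
        )


total = smem_size(164_544)  # hypothetical tmem_addr_byte_offset
check_fits(total)
print(total)  # 164548
```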