
Mojo struct

SM100AttentionSMem

struct SM100AttentionSMem[qkv_dtype: DType, rope_dtype: DType, scale_dtype: DType, //, config: FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype], *, use_order_barriers: Bool = EnableForcedOrdering]

Shared memory layout manager for SM100 Flash Attention kernels.

Stores a base pointer into dynamic shared memory and provides accessor methods for each region (Q, K, V, correction, mbarriers, tmem address). All byte-offset arithmetic is performed at compile time, so the accessors compile down to a single pointer add plus a bitcast.
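The comptime members below form a running-offset chain: each region's byte offset is the previous region's offset plus its size. As an illustration only, the pattern can be sketched in Python with made-up region sizes (the real sizes come from FA4Config; only the ordering and the running-offset arithmetic mirror the comptime members):

```python
def layout_offsets(q_bytes, kv_bytes, correction_bytes,
                   q_scale_bytes, k_scale_bytes, mbar_bytes):
    """Running-offset chain: each region starts where the previous one
    ends, mirroring the comptime *_byte_offset members of the struct."""
    offsets = {}
    cursor = 0
    for name, size in [("q", q_bytes), ("kv", kv_bytes),
                       ("correction", correction_bytes),
                       ("q_scale", q_scale_bytes),
                       ("k_scale", k_scale_bytes),
                       ("mbar", mbar_bytes),
                       ("tmem_addr", 4)]:   # single UInt32 at the end
        offsets[name] = cursor
        cursor += size
    return offsets, cursor  # cursor = total dynamic smem bytes required

# Placeholder sizes, chosen only to show the chain; not real FA4 values.
offs, total = layout_offsets(q_bytes=8192, kv_bytes=16384,
                             correction_bytes=512, q_scale_bytes=128,
                             k_scale_bytes=256, mbar_bytes=64)
```

The real struct evaluates this chain entirely at compile time, which is why each accessor reduces to one pointer add.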

Parameters

  • qkv_dtype (DType): Element type of Q/K/V data in shared memory.
  • rope_dtype (DType): Element type of Q and K rope.
  • scale_dtype (DType): Element type of the per-token scale used for Q and K.
  • config (FA4Config): FA4 configuration (tile sizes, depths, staging counts, etc.).
  • use_order_barriers (Bool): Whether forced-ordering barriers are allocated.

Fields

  • base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory; every region accessor is a compile-time offset from it.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

correction_byte_offset

comptime correction_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_bytes)

correction_bytes

comptime correction_bytes = (config * size_of[DType.float32]())

correction_offset

comptime correction_offset = SIMD((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].correction_byte_offset // size_of[DType.float32]()))

k_scale_byte_offset

comptime k_scale_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_scale_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_scale_bytes)

k_scale_bytes

comptime k_scale_bytes = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].num_k_scale_bufs * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_scale_stride_bytes)

k_scale_stride_bytes

comptime k_scale_stride_bytes = (config * size_of[scale_dtype]() if (scale_dtype != DType.invalid) else 0)

k_stage_bytes

comptime k_stage_bytes = ((config * config) * size_of[qkv_dtype]()) if config.use_fused_kv else (((config * config) * size_of[qkv_dtype]()) + ((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_depth * config) * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_dt_size))

k_total_bytes

comptime k_total_bytes = (config * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_stage_bytes)

kv_byte_offset

comptime kv_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_bytes

kv_bytes

comptime kv_bytes = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_total_bytes + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_bytes) if config.use_fused_kv else SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_stages_bytes

kv_offset

comptime kv_offset = SIMD((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset // size_of[qkv_dtype]()))

kv_stages_bytes

comptime kv_stages_bytes = (config * (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_stage_bytes + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].v_stage_bytes))

mbar_byte_offset

comptime mbar_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_scale_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_scale_bytes)

mbar_bytes

comptime mbar_bytes = (Int.__init__[UInt32](FA4MiscMBars.num_mbars()) * size_of[SharedMemBarrier]())

mbar_offset

comptime mbar_offset = SIMD((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].mbar_byte_offset // size_of[SharedMemBarrier]()))

MiscMBarsType

comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=config.use_fused_kv]

num_k_scale_bufs

comptime num_k_scale_bufs = config.num_k_scale_bufs()

num_rope_bufs

comptime num_rope_bufs = config.num_rope_buffers()

q_byte_offset

comptime q_byte_offset = 0

q_bytes

comptime q_bytes = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_nope_bytes + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_rope_bytes)

q_nope_bytes

comptime q_nope_bytes = ((config * config) * size_of[qkv_dtype]())

q_offset

comptime q_offset = 0

q_rope_byte_offset

comptime q_rope_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_nope_bytes

q_rope_bytes

comptime q_rope_bytes = ((config * config.rope_depth()) * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_dt_size)

q_scale_byte_offset

comptime q_scale_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].correction_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].correction_bytes)

q_scale_bytes

comptime q_scale_bytes = (config * size_of[scale_dtype]() if (scale_dtype != DType.invalid) else 0)

rope_byte_offset

comptime rope_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_total_bytes)

rope_bytes

comptime rope_bytes = ((SM100AttentionSMem[config, use_order_barriers=use_order_barriers].num_rope_bufs * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_stage_elems) * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_dt_size)

rope_depth

comptime rope_depth = config.rope_depth()

rope_dt_size

comptime rope_dt_size = size_of[rope_dtype]() if (rope_dtype != DType.invalid) else 0

rope_stage_elems

comptime rope_stage_elems = (config * SM100AttentionSMem[config, use_order_barriers=use_order_barriers].rope_depth)

tmem_addr_byte_offset

comptime tmem_addr_byte_offset = (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].mbar_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].mbar_bytes)

v_byte_offset

comptime v_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset if config.use_fused_kv else (SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset + SM100AttentionSMem[config, use_order_barriers=use_order_barriers].k_total_bytes)

v_stage_bytes

comptime v_stage_bytes = ((config * config) * size_of[qkv_dtype]())

Methods

__init__

__init__() -> Self

Obtain the base pointer from the kernel's dynamic shared memory.

misc_mbars

misc_mbars(self) -> SM100AttentionSMem[config, use_order_barriers=use_order_barriers].MiscMBarsType

Return the FA4MiscMBars wrapper over the mbarrier region.

Returns:

SM100AttentionSMem.MiscMBarsType (an FA4MiscMBars wrapper)

q_smem

q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q region (offset 0).

Returns:

UnsafePointer

q_rope_smem

q_rope_smem(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q rope region (after Q nope in smem).

Returns:

UnsafePointer

o_smem

o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

Same physical memory as Q, bitcast to the output element type.

Returns:

UnsafePointer

k_smem_base

k_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the K region (first stage, offset = kv_byte_offset).

Returns:

UnsafePointer

v_smem_base

v_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the V region (stage 0).

Split mode: V stage 0 starts after all K stages, at kv_byte_offset + num_kv_stages * padded_qk_depth * BN * sizeof. Fused mode: returns the same pointer as k_smem_base(), since K_nope and V share the same buffer.

Returns:

UnsafePointer
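The split-vs-fused selection above matches the comptime v_byte_offset member. A sketch of the two-branch offset computation (the parameter names stand in for the comptime members of the same names; the values are made up):

```python
def v_byte_offset(kv_byte_offset, k_total_bytes, use_fused_kv):
    """Mirror of the comptime v_byte_offset selection:
    fused mode aliases V onto the K buffer; split mode places
    V stage 0 immediately after all K stages."""
    if use_fused_kv:
        return kv_byte_offset              # same pointer as k_smem_base()
    return kv_byte_offset + k_total_bytes  # past the K stages

# Split mode: V starts after all K stages.
split = v_byte_offset(kv_byte_offset=4096, k_total_bytes=8192,
                      use_fused_kv=False)
# Fused mode: K_nope and V share one buffer.
fused = v_byte_offset(kv_byte_offset=4096, k_total_bytes=8192,
                      use_fused_kv=True)
```

In fused mode the accessor therefore returns a pointer equal to k_smem_base(); in split mode it is offset by k_total_bytes.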

rope_smem_base

rope_smem_base(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the rope region (fused mode only).

Returns:

UnsafePointer

correction_smem

correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the correction region (BM Float32 elements).

Returns:

UnsafePointer

q_scale_smem

q_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the q_scale region (BM elements).

Returns:

UnsafePointer

k_scale_smem

k_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the k_scale region (num_k_scale_bufs * BN elements).

Returns:

UnsafePointer

tmem_addr_ptr

tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Pointer to the single UInt32 storing the TMEM address.

Returns:

UnsafePointer

smem_size

static smem_size() -> Int

Total dynamic shared memory bytes required.

Returns:

Int
