Mojo struct
SM100AttentionSMem
struct SM100AttentionSMem[qkv_dtype: DType, rope_dtype: DType, scale_dtype: DType, //, config: FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype], *, use_order_barriers: Bool = EnableForcedOrdering]
Shared memory layout manager for SM100 Flash Attention kernels.
Stores a base pointer into dynamic shared memory and provides accessor methods for each region (Q, K, V, correction, scales, mbarriers, tmem address). All byte-offset arithmetic is comptime, so each accessor compiles down to a single pointer add plus a bitcast.
Parameters
- qkv_dtype (DType): Element type of Q/K/V data in shared memory.
- rope_dtype (DType): Element type of the Q and K rope components.
- scale_dtype (DType): Element type of the per-token scale used for Q and K.
- config (FA4Config): FA4 configuration (tile sizes, depths, staging counts, etc.).
- use_order_barriers (Bool): Whether forced-ordering barriers are allocated.
Fields
- base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
correction_byte_offset
comptime correction_byte_offset = (Self.kv_byte_offset + Self.kv_bytes)
correction_bytes
comptime correction_bytes = (config * size_of[DType.float32]())
correction_offset
comptime correction_offset = SIMD((Self.correction_byte_offset // size_of[DType.float32]()))
k_scale_byte_offset
comptime k_scale_byte_offset = (Self.q_scale_byte_offset + Self.q_scale_bytes)
k_scale_bytes
comptime k_scale_bytes = (Self.num_k_scale_bufs * Self.k_scale_stride_bytes)
k_scale_stride_bytes
comptime k_scale_stride_bytes = (config * size_of[scale_dtype]() if (scale_dtype != DType.invalid) else 0)
k_stage_bytes
comptime k_stage_bytes = ((config * config) * size_of[qkv_dtype]()) if config.use_fused_kv else (((config * config) * size_of[qkv_dtype]()) + ((Self.rope_depth * config) * Self.rope_dt_size))
k_total_bytes
comptime k_total_bytes = (config * Self.k_stage_bytes)
kv_byte_offset
comptime kv_byte_offset = Self.q_bytes
kv_bytes
comptime kv_bytes = (Self.k_total_bytes + Self.rope_bytes) if config.use_fused_kv else Self.kv_stages_bytes
kv_offset
comptime kv_offset = SIMD((Self.kv_byte_offset // size_of[qkv_dtype]()))
kv_stages_bytes
comptime kv_stages_bytes = (config * (Self.k_stage_bytes + Self.v_stage_bytes))
mbar_byte_offset
comptime mbar_byte_offset = (Self.k_scale_byte_offset + Self.k_scale_bytes)
mbar_bytes
comptime mbar_bytes = (Int(FA4MiscMBars.num_mbars()) * size_of[SharedMemBarrier]())
mbar_offset
comptime mbar_offset = SIMD((Self.mbar_byte_offset // size_of[SharedMemBarrier]()))
MiscMBarsType
comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=config.use_fused_kv]
num_k_scale_bufs
comptime num_k_scale_bufs = config.num_k_scale_bufs()
num_rope_bufs
comptime num_rope_bufs = config.num_rope_buffers()
q_byte_offset
comptime q_byte_offset = 0
q_bytes
comptime q_bytes = (Self.q_nope_bytes + Self.q_rope_bytes)
q_nope_bytes
comptime q_nope_bytes = ((config * config) * size_of[qkv_dtype]())
q_offset
comptime q_offset = 0
q_rope_byte_offset
comptime q_rope_byte_offset = Self.q_nope_bytes
q_rope_bytes
comptime q_rope_bytes = ((config * config.rope_depth()) * Self.rope_dt_size)
q_scale_byte_offset
comptime q_scale_byte_offset = (Self.correction_byte_offset + Self.correction_bytes)
q_scale_bytes
comptime q_scale_bytes = (config * size_of[scale_dtype]() if (scale_dtype != DType.invalid) else 0)
rope_byte_offset
comptime rope_byte_offset = (Self.kv_byte_offset + Self.k_total_bytes)
rope_bytes
comptime rope_bytes = ((Self.num_rope_bufs * Self.rope_stage_elems) * Self.rope_dt_size)
rope_depth
comptime rope_depth = config.rope_depth()
rope_dt_size
comptime rope_dt_size = size_of[rope_dtype]() if (rope_dtype != DType.invalid) else 0
rope_stage_elems
comptime rope_stage_elems = (config * Self.rope_depth)
tmem_addr_byte_offset
comptime tmem_addr_byte_offset = (Self.mbar_byte_offset + Self.mbar_bytes)
v_byte_offset
comptime v_byte_offset = Self.kv_byte_offset if config.use_fused_kv else (Self.kv_byte_offset + Self.k_total_bytes)
v_stage_bytes
comptime v_stage_bytes = ((config * config) * size_of[qkv_dtype]())
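The members above form a simple cumulative layout: each region starts where the previous one ends. A minimal Python model of the byte-offset chain, using illustrative placeholder sizes (these are not the actual FA4Config values):

```python
# Model of the SM100AttentionSMem byte-offset chain.
# All region sizes below are illustrative placeholders.
q_bytes = 32 * 1024          # Q nope + Q rope
kv_bytes = 96 * 1024         # all K/V stages (+ rope buffers in split mode)
correction_bytes = 128 * 4   # BM Float32 correction values
q_scale_bytes = 128 * 4      # per-token Q scales (0 if scale_dtype is invalid)
k_scale_bytes = 2 * 128 * 4  # num_k_scale_bufs * k_scale_stride_bytes
mbar_bytes = 16 * 8          # num_mbars * size_of[SharedMemBarrier]()

# Each region starts where the previous one ends.
q_byte_offset = 0
kv_byte_offset = q_byte_offset + q_bytes
correction_byte_offset = kv_byte_offset + kv_bytes
q_scale_byte_offset = correction_byte_offset + correction_bytes
k_scale_byte_offset = q_scale_byte_offset + q_scale_bytes
mbar_byte_offset = k_scale_byte_offset + k_scale_bytes
tmem_addr_byte_offset = mbar_byte_offset + mbar_bytes

print(tmem_addr_byte_offset)  # 133248 with the placeholder sizes above
```

Because the chain is pure comptime arithmetic in the real struct, each accessor resolves to `base + offset` at compile time.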
Methods
__init__
__init__() -> Self
Obtain the base pointer from the kernel's dynamic shared memory.
misc_mbars
misc_mbars(self) -> Self.MiscMBarsType
Return the FA4MiscMBars wrapper over the mbarrier region.
q_smem
q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the Q region (offset 0).
q_rope_smem
q_rope_smem(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the Q rope region (after Q nope in smem).
o_smem
o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
Same physical memory as Q, bitcast to the output element type.
k_smem_base
k_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the K region (first stage, offset = kv_byte_offset).
v_smem_base
v_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the V region (stage 0).
Split mode: V stage 0 starts after all K stages, at kv_byte_offset + num_kv_stages * padded_qk_depth * BN * size_of[qkv_dtype](). Fused mode: returns the same pointer as k_smem_base(), since K_nope and V share the same buffer.
rope_smem_base
rope_smem_base(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the rope region (fused mode only).
correction_smem
correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the correction region (BM Float32 elements).
q_scale_smem
q_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the q_scale region (BM elements).
k_scale_smem
k_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the k_scale region (num_k_scale_bufs * BN elements).
tmem_addr_ptr
tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Pointer to the single UInt32 storing the TMEM address.
smem_size