
Mojo struct

SM100AttentionSMem

struct SM100AttentionSMem[qkv_dtype: DType, rope_dtype: DType, scale_dtype: DType, //, config: FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype], *, use_order_barriers: Bool = EnableForcedOrdering]

Shared memory layout manager for SM100 Flash Attention kernels.

Stores a base pointer into dynamic shared memory and provides an accessor method for each region (Q, K, V, rope, correction, Q/K scales, mbarriers, and the TMEM address). All byte-offset arithmetic happens at compile time (comptime), so each accessor compiles down to a single pointer add plus a bitcast.

Parameters

  • qkv_dtype (DType): Element type of the Q/K/V data in shared memory.
  • rope_dtype (DType): Element type of the Q and K rope components. When it is DType.invalid, the rope regions occupy zero bytes.
  • scale_dtype (DType): Element type of the per-token scales used for Q and K. When it is DType.invalid, the scale regions occupy zero bytes.
  • config (FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype]): FA4 configuration (tile sizes, depths, staging counts, etc.).
  • use_order_barriers (Bool): Whether forced-ordering barriers are allocated. Defaults to EnableForcedOrdering.

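Fields

  • base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Byte pointer to the start of the kernel's dynamic shared memory. Every region accessor adds a comptime byte offset to it.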

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

correction_byte_offset

comptime correction_byte_offset = Self.q_bytes + ((Self.rope_bytes + Self.k_stage_bytes * config.num_kv_stages) if config.use_fused_kv else (Self.v_stage_bytes * config.num_kv_stages + Self.k_stage_bytes * config.num_kv_stages))

correction_bytes

comptime correction_bytes = (2 if config.num_qo == 1 else 1) * config.BM * size_of[DType.float32]()
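As a quick size check, the correction_bytes formula can be mirrored in plain Python. The helper below is hypothetical (not part of the Mojo API), BM and num_qo are made-up inputs, and size_of[DType.float32]() is taken as 4 bytes.

```python
FLOAT32_BYTES = 4  # size_of[DType.float32]()


def correction_bytes(bm: int, num_qo: int) -> int:
    """Mirror of: (2 if num_qo == 1 else 1) * BM * size_of[DType.float32]()."""
    return (2 if num_qo == 1 else 1) * bm * FLOAT32_BYTES


# Hypothetical tile height BM = 128.
print(correction_bytes(128, num_qo=1))  # 1024
print(correction_bytes(128, num_qo=2))  # 512
```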

correction_offset

comptime correction_offset = SIMD(Self.correction_byte_offset // size_of[DType.float32]())

k_scale_byte_offset

comptime k_scale_byte_offset = Self.q_scale_byte_offset + config.BM * (size_of[scale_dtype]() if scale_dtype != DType.invalid else 0)

k_scale_bytes

comptime k_scale_bytes = config.num_k_scale_bufs() * config.BN * (size_of[scale_dtype]() if scale_dtype != DType.invalid else 0)

k_scale_stride_bytes

comptime k_scale_stride_bytes = config * (size_of[scale_dtype]() if scale_dtype != DType.invalid else 0)

k_stage_bytes

comptime k_stage_bytes = (config.v_cols_per_cta() * config.BN * size_of[qkv_dtype]()) if config.use_fused_kv else (config.v_cols_per_cta() * config.BN * size_of[qkv_dtype]() + config.k_rows_per_cta() * config.rope_depth() * Self.rope_dt_size)
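The two branches of k_stage_bytes can be compared with a plain-Python mirror. The helper and every input value below are hypothetical; rope_dt_size stands for the member of the same name (0 when rope_dtype is DType.invalid).

```python
def k_stage_bytes(v_cols_per_cta, bn, k_rows_per_cta, rope_depth,
                  qkv_size, rope_dt_size, use_fused_kv):
    """Mirror of the k_stage_bytes comptime expression (sizes in bytes)."""
    nope = v_cols_per_cta * bn * qkv_size
    if use_fused_kv:
        # Fused: the stage holds only the shared K_nope/V tile.
        return nope
    # Split: the stage adds a k_rows_per_cta x rope_depth rope block.
    return nope + k_rows_per_cta * rope_depth * rope_dt_size


# Hypothetical numbers: 512 columns per CTA, BN = 64, 2-byte Q/K/V,
# 64 K rows per CTA, rope depth 64 with a 2-byte rope dtype.
args = dict(v_cols_per_cta=512, bn=64, k_rows_per_cta=64, rope_depth=64,
            qkv_size=2, rope_dt_size=2)
print(k_stage_bytes(**args, use_fused_kv=True))   # 65536
print(k_stage_bytes(**args, use_fused_kv=False))  # 73728
```

With rope disabled (rope_dt_size of 0), the two branches produce the same stage size.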

k_total_bytes

comptime k_total_bytes = config * Self.k_stage_bytes

kv_byte_offset

comptime kv_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_bytes

kv_bytes

comptime kv_bytes = (Self.k_stage_bytes * config.num_kv_stages + Self.rope_bytes) if config.use_fused_kv else Self.kv_stages_bytes

kv_offset

comptime kv_offset = SIMD(Self.kv_byte_offset // size_of[qkv_dtype]())
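kv_offset, correction_offset and mbar_offset all follow the same pattern: a comptime byte offset divided by the size of the element type that region holds. A plain-Python sketch of that conversion; the element_offset helper and the sample offsets are hypothetical, and the divisibility check is an extra guard added for this sketch (the Mojo members simply use floor division).

```python
def element_offset(byte_offset: int, elem_size: int) -> int:
    """Convert a byte offset into an element index, as kv_offset does."""
    # Sketch-only guard: floor division would silently truncate otherwise.
    if byte_offset % elem_size != 0:
        raise ValueError("byte offset is not a multiple of the element size")
    return byte_offset // elem_size


# Hypothetical kv_byte_offset: BM = 128, padded_ov_depth = 128, 2-byte Q/K/V,
# rope disabled, so q_bytes = 128 * 128 * 2.
kv_byte_offset = 128 * 128 * 2
print(element_offset(kv_byte_offset, 2))  # 16384 (the kv_offset analogue)

# Same conversion for a Float32 region, e.g. a hypothetical
# correction_byte_offset of 200704 bytes.
print(element_offset(200704, 4))          # 50176
```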

kv_stages_bytes

comptime kv_stages_bytes = config * (Self.v_stage_bytes + Self.k_stage_bytes)

mbar_byte_offset

comptime mbar_byte_offset = Self.k_scale_byte_offset + Self.k_scale_bytes

mbar_bytes

comptime mbar_bytes = (SIMD(FA4MiscMBars.num_mbars()) * size_of[SharedMemBarrier]())

mbar_offset

comptime mbar_offset = SIMD(Self.mbar_byte_offset // size_of[SharedMemBarrier]())

MiscMBarsType

comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=config.use_fused_kv, pair_cta=config.pair_cta, num_qo=config.num_qo]

num_k_scale_bufs

comptime num_k_scale_bufs = config.num_k_scale_bufs()

num_rope_bufs

comptime num_rope_bufs = config.num_rope_buffers()

q_byte_offset

comptime q_byte_offset = Int(0)

q_bytes

comptime q_bytes = config.BM * config.padded_ov_depth * size_of[qkv_dtype]() + config.BM * config.rope_depth() * Self.rope_dt_size
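The Q region is the nope block followed by the rope block. A plain-Python mirror of q_nope_bytes, q_rope_bytes, q_rope_byte_offset and q_bytes; the q_layout helper and its inputs are hypothetical.

```python
def q_layout(bm, padded_ov_depth, rope_depth, qkv_size, rope_dt_size):
    """Mirror q_nope_bytes, q_rope_bytes, q_rope_byte_offset and q_bytes."""
    q_nope_bytes = bm * padded_ov_depth * qkv_size
    q_rope_bytes = rope_depth * bm * rope_dt_size
    return {
        "q_byte_offset": 0,
        "q_rope_byte_offset": q_nope_bytes,  # rope data follows the nope data
        "q_bytes": q_nope_bytes + q_rope_bytes,
    }


# Hypothetical: BM = 64, padded_ov_depth = 512, rope_depth = 64, 2-byte types.
print(q_layout(64, 512, 64, qkv_size=2, rope_dt_size=2))
# {'q_byte_offset': 0, 'q_rope_byte_offset': 65536, 'q_bytes': 73728}
```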

q_nope_bytes

comptime q_nope_bytes = (Int((mul config.BM, config.padded_ov_depth)) * size_of[qkv_dtype]())

q_offset

comptime q_offset = Int32(0)

q_rope_byte_offset

comptime q_rope_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_nope_bytes

q_rope_bytes

comptime q_rope_bytes = config.rope_depth() * config.BM * Self.rope_dt_size

q_scale_byte_offset

comptime q_scale_byte_offset = Self.correction_byte_offset + Self.correction_bytes

q_scale_bytes

comptime q_scale_bytes = config * (size_of[scale_dtype]() if scale_dtype != DType.invalid else 0)

rope_byte_offset

comptime rope_byte_offset = Self.q_bytes + Self.k_stage_bytes * config.num_kv_stages

rope_bytes

comptime rope_bytes = config.num_rope_buffers() * config.rope_depth() * config.BN * Self.rope_dt_size

rope_depth

comptime rope_depth = config.rope_depth()

rope_dt_size

comptime rope_dt_size = size_of[rope_dtype]() if (rope_dtype != DType.invalid) else Int(0)
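rope_dt_size makes a disabled rope dtype contribute zero bytes, so every rope size that is multiplied by it collapses to zero; the scale members use the same conditional for scale_dtype. A plain-Python sketch, where the DTYPE_SIZES table and the string "invalid" (standing in for DType.invalid) are assumptions of the sketch:

```python
# Stand-in element sizes; "invalid" plays the role of DType.invalid.
DTYPE_SIZES = {"bfloat16": 2, "float32": 4}


def dt_size(dtype: str) -> int:
    """Mirror of: size_of[dtype]() if dtype != DType.invalid else 0."""
    return 0 if dtype == "invalid" else DTYPE_SIZES[dtype]


def rope_bytes(num_rope_buffers, rope_depth, bn, rope_dtype):
    """Mirror of rope_bytes: num_rope_buffers * rope_depth * BN * rope_dt_size."""
    return num_rope_buffers * rope_depth * bn * dt_size(rope_dtype)


print(rope_bytes(2, 64, 128, "bfloat16"))  # 32768
print(rope_bytes(2, 64, 128, "invalid"))   # 0: the region disappears
```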

rope_stage_elems

comptime rope_stage_elems = (config * config.rope_depth())

tmem_addr_byte_offset

comptime tmem_addr_byte_offset = Self.mbar_byte_offset + Self.mbar_bytes
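The members from correction_byte_offset through tmem_addr_byte_offset each start where the previous region ends. A plain-Python sketch of that chain following the formulas above; kv_end (the value of correction_byte_offset), BM, BN, num_qo, the scale size, num_k_scale_bufs and the barrier count are hypothetical inputs, and the 8-byte mbar_size is an assumption standing in for size_of[SharedMemBarrier]().

```python
def tail_offsets(kv_end, bm, bn, num_qo, scale_size, num_k_scale_bufs,
                 num_mbars, mbar_size=8):
    """Byte offsets of the regions that follow the Q and KV data."""
    correction = kv_end                                    # correction_byte_offset
    correction_bytes = (2 if num_qo == 1 else 1) * bm * 4  # Float32 entries
    q_scale = correction + correction_bytes                # q_scale_byte_offset
    k_scale = q_scale + bm * scale_size                    # k_scale_byte_offset
    mbar = k_scale + num_k_scale_bufs * bn * scale_size    # mbar_byte_offset
    tmem_addr = mbar + num_mbars * mbar_size               # tmem_addr_byte_offset
    return {"correction": correction, "q_scale": q_scale,
            "k_scale": k_scale, "mbar": mbar, "tmem_addr": tmem_addr}


# Hypothetical: KV data ends at byte 131072, BM = BN = 128, num_qo = 1,
# 4-byte scales, 2 K-scale buffers, 20 barriers.
print(tail_offsets(131072, 128, 128, num_qo=1, scale_size=4,
                   num_k_scale_bufs=2, num_mbars=20))
# {'correction': 131072, 'q_scale': 132096, 'k_scale': 132608,
#  'mbar': 133632, 'tmem_addr': 133792}
```

Setting scale_size to 0 (scale_dtype disabled) makes q_scale, k_scale and mbar coincide, since both scale regions shrink to zero bytes.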

v_byte_offset

comptime v_byte_offset = Self.kv_byte_offset if config.use_fused_kv else Self.kv_byte_offset + Self.k_stage_bytes * config.num_kv_stages
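The fused and split KV layouts differ in where V and the rope stages start. A plain-Python mirror of v_byte_offset, rope_byte_offset and the end of the KV data (the value correction_byte_offset takes); the kv_offsets helper and all sizes are hypothetical, chosen to describe a 2-stage configuration.

```python
def kv_offsets(q_bytes, k_stage_bytes, v_stage_bytes, rope_bytes,
               num_kv_stages, use_fused_kv):
    """Mirror v_byte_offset, rope_byte_offset and the end of the KV data."""
    k_end = q_bytes + k_stage_bytes * num_kv_stages  # K stages start at q_bytes
    if use_fused_kv:
        # K_nope and V share one buffer; the rope region follows the stages.
        return {"v": q_bytes, "rope": k_end, "kv_end": k_end + rope_bytes}
    # Split: V stages follow all K stages (each K stage includes its rope term).
    return {"v": k_end, "kv_end": k_end + v_stage_bytes * num_kv_stages}


fused = kv_offsets(73728, 65536, 65536, 16384, 2, use_fused_kv=True)
split = kv_offsets(73728, 73728, 65536, 16384, 2, use_fused_kv=False)
print(fused)  # {'v': 73728, 'rope': 204800, 'kv_end': 221184}
print(split)  # {'v': 221184, 'kv_end': 352256}
```

In the fused case V starts at the same byte as the K stages, which is why v_smem_base() and k_smem_base() return the same pointer.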

v_stage_bytes

comptime v_stage_bytes = (Int((mul config.v_cols_per_cta(), config.BN)) * size_of[qkv_dtype]())

Methods

__init__

def __init__() -> Self

Obtain the base pointer from the kernel's dynamic shared memory.

misc_mbars

def misc_mbars(self) -> Self.MiscMBarsType

Return the FA4MiscMBars wrapper over the mbarrier region.

Returns:

Self.MiscMBarsType

q_smem

def q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q region (offset 0).

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

q_rope_smem

def q_rope_smem(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the Q rope region, which follows the Q nope data in shared memory (at q_rope_byte_offset).

Returns:

UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

o_smem

def o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the output buffer. It is the same physical memory as the Q region, bitcast to output_type, so writes through it overwrite the Q data.
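o_smem allocates nothing: it reinterprets the Q region's bytes with a different element type. Python's memoryview.cast gives a rough host-side analogue of that reinterpretation (an illustration only, not how the kernel is written): two typed views over one buffer see each other's writes. The 64-byte buffer and the chosen element formats are arbitrary.

```python
q_region = bytearray(64)                 # stand-in for the Q bytes
q_view = memoryview(q_region).cast("H")  # 2-byte elements (e.g. a bf16 Q tile)
o_view = memoryview(q_region).cast("f")  # 4-byte elements (e.g. float32 output)

o_view[0] = 1.0                          # write one "output" element
print(any(q_view.tolist()[:2]))          # True: the Q view sees the write
print(len(q_view), len(o_view))          # 32 16: same bytes, different counts
```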

Returns:

UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]

k_smem_base

def k_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the K region (first stage, offset = kv_byte_offset).

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

v_smem_base

def v_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the V region (stage 0).

In split mode, V stage 0 starts after all K stages, at kv_byte_offset + num_kv_stages * padded_qk_depth * BN * sizeof. In fused mode, it returns the same pointer as k_smem_base(), because K_nope and V share the same buffer.

Returns:

UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

rope_smem_base

def rope_smem_base(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the rope region (fused mode only).

Returns:

UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

correction_smem

def correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the correction region: 2 * BM Float32 elements when config.num_qo == 1, otherwise BM Float32 elements (see correction_bytes).

Returns:

UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]

q_scale_smem

def q_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the q_scale region (BM elements).

Returns:

UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

k_scale_smem

def k_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

Base of the k_scale region (num_k_scale_bufs * BN elements).

Returns:

UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]

tmem_addr_ptr

def tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

Pointer to the single UInt32 storing the TMEM address.

Returns:

UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]

smem_size

static def smem_size() -> Int

Total dynamic shared memory bytes required.

Returns:

Int