For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
SM100AttentionSMem
struct SM100AttentionSMem[qkv_dtype: DType, rope_dtype: DType, scale_dtype: DType, //, config: FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype], *, use_order_barriers: Bool = EnableForcedOrdering]
Shared memory layout manager for SM100 Flash Attention kernels.
Stores a base pointer into dynamic shared memory and provides accessor methods for each region (Q, K, V, correction, mbarriers, tmem address). All byte-offset arithmetic is comptime so the accessors compile down to a single pointer add + bitcast.
Parameters
- qkv_dtype (DType): Element type of Q/K/V data in shared memory.
- rope_dtype (DType): Element type of Q and K rope.
- scale_dtype (DType): Element type of the per-token scale used for Q and K.
- config (FA4Config[qkv_dtype, rope_dtype=rope_dtype, scale_dtype=scale_dtype]): FA4 configuration (tile sizes, depths, staging counts, etc.).
- use_order_barriers (Bool): Whether forced-ordering barriers are allocated.
Fields
- base (UnsafePointer[UInt8, MutAnyOrigin, address_space=AddressSpace.SHARED]): Base pointer into the kernel's dynamic shared memory.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
correction_byte_offset
comptime correction_byte_offset = (Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth))) + Int((add (mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))) if config.use_fused_kv else Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN, config.num_kv_stages), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))))
correction_bytes
comptime correction_bytes = (Int((mul Int(2) if (eq config.num_qo, 1) else Int(1), config.BM)) * size_of[DType.float32]())
correction_offset
comptime correction_offset = SIMD((Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth), Int((add (mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))) if config.use_fused_kv else Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN, config.num_kv_stages), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))))) // size_of[DType.float32]()))
k_scale_byte_offset
comptime k_scale_byte_offset = (Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth), (mul size_of[DType.float32](), Int(2) if (eq config.num_qo, 1) else Int(1), config.BM), Int((add (mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))) if config.use_fused_kv else Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN, config.num_kv_stages), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))))) + Int((mul size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM)))
k_scale_bytes
comptime k_scale_bytes = (config.num_k_scale_bufs() * Int((mul size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN)))
k_scale_stride_bytes
comptime k_scale_stride_bytes = (config.BN * size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0))
k_stage_bytes
comptime k_stage_bytes = (Int((mul config.v_cols_per_cta(), config.BN)) * size_of[qkv_dtype]()) if config.use_fused_kv else (Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) + Int((mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0))))
k_total_bytes
comptime k_total_bytes = (config.num_kv_stages * Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))))
kv_byte_offset
comptime kv_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_bytes
kv_bytes
comptime kv_bytes = (Int((mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages)) + Int((mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN))) if config.use_fused_kv else SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_stages_bytes
kv_offset
comptime kv_offset = SIMD((Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth))) // size_of[qkv_dtype]()))
kv_stages_bytes
comptime kv_stages_bytes = (config.num_kv_stages * Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN), Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))))))
mbar_byte_offset
comptime mbar_byte_offset = (Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth), (mul size_of[DType.float32](), Int(2) if (eq config.num_qo, 1) else Int(1), config.BM), (mul size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), Int((add (mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))) if config.use_fused_kv else Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN, config.num_kv_stages), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))))) + Int((mul config.num_k_scale_bufs(), size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else 
Int(0), config.BN)))
mbar_bytes
comptime mbar_bytes = (SIMD(FA4MiscMBars.num_mbars()) * size_of[SharedMemBarrier]())
mbar_offset
comptime mbar_offset = SIMD((Int((add (mul config.num_k_scale_bufs(), size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth), (mul size_of[DType.float32](), Int(2) if (eq config.num_qo, 1) else Int(1), config.BM), (mul size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), Int((add (mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))) if config.use_fused_kv else Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN, config.num_kv_stages), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), 
config.num_kv_stages))))) // size_of[SharedMemBarrier]()))
MiscMBarsType
comptime MiscMBarsType = FA4MiscMBars[num_qk_stages=config.num_qk_stages, num_pv_stages=config.num_pv_stages, num_kv_stages=config.num_kv_stages, use_order_barriers=use_order_barriers, use_fused_kv=config.use_fused_kv, pair_cta=config.pair_cta, num_qo=config.num_qo]
num_k_scale_bufs
comptime num_k_scale_bufs = config.num_k_scale_bufs()
num_rope_bufs
comptime num_rope_bufs = config.num_rope_buffers()
q_byte_offset
comptime q_byte_offset = Int(0)
q_bytes
comptime q_bytes = (Int((mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth)) + Int((mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM)))
q_nope_bytes
comptime q_nope_bytes = (Int((mul config.BM, config.padded_ov_depth)) * size_of[qkv_dtype]())
q_offset
comptime q_offset = Int32(0)
q_rope_byte_offset
comptime q_rope_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].q_nope_bytes
q_rope_bytes
comptime q_rope_bytes = (Int((mul config.rope_depth(), config.BM)) * size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0))
q_scale_byte_offset
comptime q_scale_byte_offset = (Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth), Int((add (mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))) if config.use_fused_kv else Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN, config.num_kv_stages), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))))) + Int((mul size_of[DType.float32](), Int(2) if (eq config.num_qo, 1) else Int(1), config.BM)))
q_scale_bytes
comptime q_scale_bytes = (config.BM * size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0))
rope_byte_offset
comptime rope_byte_offset = (Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth))) + Int((mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages)))
rope_bytes
comptime rope_bytes = (Int((mul config.num_rope_buffers(), config.rope_depth(), config.BN)) * size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0))
rope_depth
comptime rope_depth = config.rope_depth()
rope_dt_size
comptime rope_dt_size = size_of[rope_dtype]() if (rope_dtype != DType.invalid) else Int(0)
rope_stage_elems
comptime rope_stage_elems = (config.BN * config.rope_depth())
tmem_addr_byte_offset
comptime tmem_addr_byte_offset = (Int((add (mul config.num_k_scale_bufs(), size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth), (mul size_of[DType.float32](), Int(2) if (eq config.num_qo, 1) else Int(1), config.BM), (mul size_of[scale_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> scale_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), Int((add (mul config.num_rope_buffers(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BN), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages))) if config.use_fused_kv else Int((add (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN, config.num_kv_stages), (mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), 
config.num_kv_stages))))) + Int((mul SIMD(FA4MiscMBars.num_mbars()), size_of[SharedMemBarrier]())))
v_byte_offset
comptime v_byte_offset = SM100AttentionSMem[config, use_order_barriers=use_order_barriers].kv_byte_offset if config.use_fused_kv else (Int((add (mul config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0), config.BM), (mul size_of[qkv_dtype](), config.BM, config.padded_ov_depth))) + Int((mul Int((mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN)) if config.use_fused_kv else Int((add (mul config.k_rows_per_cta(), config.rope_depth(), size_of[rope_dtype]() if (xor (eq #pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@std::@builtin::@dtype::@DType> rope_dtype, "_mlir_value">>, 0), True) else Int(0)), (mul config.v_cols_per_cta(), size_of[qkv_dtype](), config.BN))), config.num_kv_stages)))
v_stage_bytes
comptime v_stage_bytes = (Int((mul config.v_cols_per_cta(), config.BN)) * size_of[qkv_dtype]())
Methods
__init__
def __init__() -> Self
Obtain the base pointer from the kernel's dynamic shared memory.
misc_mbars
def misc_mbars(self) -> Self.MiscMBarsType
Return the FA4MiscMBars wrapper over the mbarrier region.
Returns:
Self.MiscMBarsType
q_smem
def q_smem(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the Q region (offset 0).
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
q_rope_smem
def q_rope_smem(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the Q rope region (after Q nope in smem).
Returns:
UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
o_smem
def o_smem[output_type: DType](self) -> UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
Same physical memory as Q, bitcast to the output element type.
Returns:
UnsafePointer[Scalar[output_type], MutAnyOrigin, address_space=AddressSpace.SHARED]
k_smem_base
def k_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the K region (first stage, offset = kv_byte_offset).
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
v_smem_base
def v_smem_base(self) -> UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the V region (stage 0).
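Split mode: V stage 0 starts after all K stages, at kv_byte_offset + num_kv_stages * k_stage_bytes, which is what the v_byte_offset member evaluates to (the original wording gives the per-stage K size as padded_qk_depth * BN * sizeof; presumably these are equal, to be confirmed against FA4Config). Fused mode: returns the same pointer as k_smem_base(), since K_nope and V share the same buffer.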
Returns:
UnsafePointer[Scalar[qkv_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
rope_smem_base
def rope_smem_base(self) -> UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the rope region (fused mode only).
Returns:
UnsafePointer[Scalar[rope_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
correction_smem
def correction_smem(self) -> UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the correction region (BM Float32 elements, or 2 * BM when config.num_qo == 1, as given by correction_bytes).
Returns:
UnsafePointer[Float32, MutAnyOrigin, address_space=AddressSpace.SHARED]
q_scale_smem
def q_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the q_scale region (BM elements).
Returns:
UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
k_scale_smem
def k_scale_smem(self) -> UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
Base of the k_scale region (num_k_scale_bufs * BN elements).
Returns:
UnsafePointer[Scalar[scale_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED]
tmem_addr_ptr
def tmem_addr_ptr(self) -> UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
Pointer to the single UInt32 storing the TMEM address.
Returns:
UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]
smem_size
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!