Mojo struct
FA4Config
struct FA4Config[qkv_dtype: DType, *, rope_dtype: DType = DType.invalid, scale_dtype: DType = DType.invalid]
Fields
- MMA_M (Int)
- BM (Int)
- BN (Int)
- BK0 (Int)
- BK1 (Int)
- qk_depth (Int)
- padded_qk_depth (Int)
- ov_depth (Int)
- padded_ov_depth (Int)
- group (Int)
- num_q_heads (Int)
- num_kv_heads (Int)
- TMEM_S1 (Int)
- TMEM_O0 (Int)
- TMEM_O1 (Int)
- TMEM_P0 (Int)
- TMEM_P1 (Int)
- TMEM_C0 (Int)
- TMEM_C1 (Int)
- tmem_used (Int)
- num_kv_stages (Int)
- num_qk_stages (Int)
- num_pv_stages (Int)
- smem_used (Int)
- fuse_gqa (Bool)
- swizzle_mode (TensorMapSwizzle)
- use_fused_kv (Bool)
- pair_cta (Bool)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
mbar_size
comptime mbar_size = size_of[DType.int64]()
MMA_K
comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32
num_correction_cols
comptime num_correction_cols = 1
num_threads
comptime num_threads = 512
qkv_dtype_size
comptime qkv_dtype_size = size_of[qkv_dtype]()
rope_dtype_size
comptime rope_dtype_size = size_of[rope_dtype]()
scale_dtype_size
comptime scale_dtype_size = size_of[scale_dtype]()
sm100_smem_carveout
comptime sm100_smem_carveout = (B200 - 1024)
sm100_tmem_cols
comptime sm100_tmem_cols = 512
TMEM_S0
comptime TMEM_S0 = 0
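The size-derived comptime members reduce to simple arithmetic on the element type. The Python sketch below models the documented expressions; it is not the Mojo API. The `DTYPE_SIZES` table stands in for Mojo's `size_of[dtype]()`, and `HALF_FLOATS` is an assumption modeling `DType.is_half_float()` as "fp16 or bf16".

```python
# Hypothetical byte sizes standing in for Mojo's size_of[dtype]().
DTYPE_SIZES = {"float8_e4m3fn": 1, "bfloat16": 2, "float16": 2, "int64": 8}

# Assumption: models DType.is_half_float() as "16-bit float type".
HALF_FLOATS = {"bfloat16", "float16"}


def mma_k(qkv_dtype: str) -> int:
    """Model of: comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32."""
    return 16 if qkv_dtype in HALF_FLOATS else 32


# Model of: comptime mbar_size = size_of[DType.int64]()
MBAR_SIZE = DTYPE_SIZES["int64"]

for dtype in ("bfloat16", "float8_e4m3fn"):
    k = mma_k(dtype)
    # Bytes spanned by MMA_K elements of this dtype along the K dimension.
    print(dtype, k, k * DTYPE_SIZES[dtype])
```

For both a 2-byte and a 1-byte element type, MMA_K elements span 32 bytes along K.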
Methods
__init__
__init__(*, num_q_heads: Int, group: Int, qk_depth: Int, ov_depth: Int, swizzle_mode: TensorMapSwizzle, page_size: Int, is_mla: Bool, pair_cta: Bool = False) -> Self
BM_eff
BM_eff(self) -> Int
Number of distinct sequence positions per full tile. When fuse_gqa is set, each tile covers BM // group sequence positions across group heads.
Returns: Int
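A minimal Python sketch of the BM_eff rule, assuming `group` is the number of query heads sharing one KV head and that, without GQA fusion, BM_eff is simply BM (the description only covers the fused case). The function name and arguments are hypothetical stand-ins for the struct fields.

```python
def bm_eff(bm: int, group: int, fuse_gqa: bool) -> int:
    """Model of FA4Config.BM_eff: distinct sequence positions per full tile."""
    if fuse_gqa:
        # The BM tile rows are split across `group` query heads that share a
        # KV head, leaving BM // group distinct sequence positions.
        return bm // group
    # Assumption: without GQA fusion, every tile row is its own position.
    return bm


print(bm_eff(128, 8, fuse_gqa=True))   # 16
print(bm_eff(128, 8, fuse_gqa=False))  # 128
```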
num_qo
cta_group
PairBM_eff
v_cols_per_cta
k_rows_per_cta
q_nope_bytes
q_rope_bytes
q_rope_bytes(self) -> Int
Size in bytes of the Q rope region. Uses rope_dtype_size when rope_dtype is set, otherwise qkv_dtype_size.
Returns: Int
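The element-size fallback can be sketched in Python as follows. This is a hypothetical model: the row count `rows` is an assumption (the description does not say how many rows the region holds), as is reporting an unset rope_dtype as size 0.

```python
def q_rope_bytes(rows: int, rope_depth: int,
                 qkv_dtype_size: int, rope_dtype_size: int) -> int:
    """Hypothetical model of FA4Config.q_rope_bytes."""
    # Assumption: an unset rope_dtype (DType.invalid) reports size 0,
    # in which case the Q/K/V element size is used instead.
    elem_size = rope_dtype_size if rope_dtype_size > 0 else qkv_dtype_size
    # Assumption: the region holds `rows` rows of `rope_depth` elements each.
    return rows * rope_depth * elem_size


print(q_rope_bytes(64, 64, qkv_dtype_size=1, rope_dtype_size=2))  # 8192
print(q_rope_bytes(64, 64, qkv_dtype_size=1, rope_dtype_size=0))  # 4096
```

Keeping a separate rope element size lets the rope region stay wider than the Q/K/V elements, for example 16-bit rope values next to 8-bit Q/K/V.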
rope_depth
rope_depth(self) -> Int
Depth of the rope part of Q and K, computed as padded_qk_depth - padded_ov_depth. This is 0 for MHA, where qk_depth == ov_depth.
Returns: Int
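A Python sketch of the rope_depth formula. The depth values in the usage lines are illustrative only and are not taken from this page.

```python
def rope_depth(padded_qk_depth: int, padded_ov_depth: int) -> int:
    """Model of FA4Config.rope_depth."""
    # The rope columns are the Q/K depth that exceeds the O/V depth.
    return padded_qk_depth - padded_ov_depth


print(rope_depth(576, 512))  # 64: Q/K carry a rope tail beyond the V depth
print(rope_depth(128, 128))  # 0: MHA, qk_depth == ov_depth
```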
num_rope_buffers
num_rope_buffers(self) -> Int
Number of separate rope smem buffers (fused mode only).
In fused mode, K tiles alternate with V tiles in the pipeline, so at most ceildiv(num_kv_stages, 2) K tiles can be in flight at once, and only that many rope buffers are needed. For MHA (rope_depth == 0), no rope buffers are needed.
Returns: Int
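The fused-mode bound can be modeled in Python as below. This sketch covers fused mode only, since the description does not state what the method returns in split mode; the function name is hypothetical.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for positive b."""
    return -(-a // b)


def num_rope_buffers_fused(num_kv_stages: int, rope_depth: int) -> int:
    """Model of FA4Config.num_rope_buffers in fused KV mode only."""
    if rope_depth == 0:
        return 0  # MHA: there is no rope region to buffer
    # With K and V tiles alternating, any window of num_kv_stages consecutive
    # tiles contains at most ceildiv(num_kv_stages, 2) K tiles.
    return ceildiv(num_kv_stages, 2)


for stages in (2, 3, 4):
    print(stages, num_rope_buffers_fused(stages, rope_depth=64))
```

With 3 stages, a window ordered K, V, K holds two K tiles, which is why the count rounds up rather than down.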
num_k_scale_bufs
num_k_scale_bufs(self) -> Int
Number of staged k_scale smem buffers.
In fused mode, K tiles alternate with V tiles, so at most ceildiv(num_kv_stages, 2) K tiles are in flight at once. In split mode, each KV stage has its own K buffer. Returns 0 when scale_dtype_size == 0 (no per-token scaling).
Returns: Int
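The three cases above can be sketched in Python, assuming the use_fused_kv field selects fused mode and that split mode needs exactly num_kv_stages buffers (inferred from "each KV stage has its own K buffer"). The scale size of 4 bytes in the usage lines is illustrative.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for positive b."""
    return -(-a // b)


def num_k_scale_bufs(num_kv_stages: int, use_fused_kv: bool,
                     scale_dtype_size: int) -> int:
    """Model of FA4Config.num_k_scale_bufs."""
    if scale_dtype_size == 0:
        return 0  # no per-token scaling
    if use_fused_kv:
        # K and V alternate, so at most half the stages (rounded up) hold K.
        return ceildiv(num_kv_stages, 2)
    # Split mode: every KV stage has its own K buffer (inferred count).
    return num_kv_stages


print(num_k_scale_bufs(4, use_fused_kv=True, scale_dtype_size=4))   # 2
print(num_k_scale_bufs(4, use_fused_kv=False, scale_dtype_size=4))  # 4
print(num_k_scale_bufs(4, use_fused_kv=True, scale_dtype_size=0))   # 0
```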
supported
description
correction_smem_elements
num_active_warps_per_group
num_active_threads_per_group