Mojo struct

FA4Config

struct FA4Config[qkv_dtype: DType, *, rope_dtype: DType = DType.invalid, scale_dtype: DType = DType.invalid]

Fields

  • MMA_M (Int)
  • BM (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • qk_depth (Int)
  • padded_qk_depth (Int)
  • ov_depth (Int)
  • padded_ov_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_S1 (Int)
  • TMEM_O0 (Int)
  • TMEM_O1 (Int)
  • TMEM_P0 (Int)
  • TMEM_P1 (Int)
  • TMEM_C0 (Int)
  • TMEM_C1 (Int)
  • tmem_used (Int)
  • num_kv_stages (Int)
  • num_qk_stages (Int)
  • num_pv_stages (Int)
  • smem_used (Int)
  • fuse_gqa (Bool)
  • swizzle_mode (TensorMapSwizzle)
  • use_fused_kv (Bool)
  • pair_cta (Bool)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

mbar_size

comptime mbar_size = size_of[DType.int64]()

MMA_K

comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32

num_correction_cols

comptime num_correction_cols = 1

num_threads

comptime num_threads = 512

qkv_dtype_size

comptime qkv_dtype_size = size_of[qkv_dtype]()

rope_dtype_size

comptime rope_dtype_size = size_of[rope_dtype]()

scale_dtype_size

comptime scale_dtype_size = size_of[scale_dtype]()

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512

TMEM_S0

comptime TMEM_S0 = 0

Methods

__init__

__init__(*, num_q_heads: Int, group: Int, qk_depth: Int, ov_depth: Int, swizzle_mode: TensorMapSwizzle, page_size: Int, is_mla: Bool, pair_cta: Bool = False) -> Self

BM_eff

BM_eff(self) -> Int

Number of distinct sequence positions per full tile. When fuse_gqa is set, each tile covers BM // group sequence positions × group heads.

Returns:

Int
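
The tile arithmetic above can be sketched in Python (not Mojo) so it runs in isolation. The function `bm_eff` and its parameters are hypothetical stand-ins for the struct fields. The non-fused branch, which returns BM unchanged, is an assumption: the description only covers the fuse_gqa case.

```python
def bm_eff(bm: int, group: int, fuse_gqa: bool) -> int:
    """Model of the sequence positions covered by one full tile."""
    if fuse_gqa:
        # Documented case: BM rows are split across `group` heads,
        # so only BM // group distinct sequence positions remain.
        return bm // group
    # ASSUMPTION: without GQA fusion every row is its own position.
    return bm


# Illustrative values only; they are not taken from the source.
fused = bm_eff(bm=128, group=4, fuse_gqa=True)
unfused = bm_eff(bm=128, group=4, fuse_gqa=False)
```

With these illustrative inputs, the fused tile covers 32 sequence positions across 4 heads, which still fills all 128 rows of the tile.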

num_qo

num_qo(self) -> Int

Returns:

Int

cta_group

cta_group(self) -> Int

Returns:

Int

PairBM_eff

PairBM_eff(self) -> Int

Sequence positions covered by both CTAs in a pair.

Returns:

Int

v_cols_per_cta

v_cols_per_cta(self) -> Int

V columns stored in this CTA's SMEM.

Returns:

Int

k_rows_per_cta

k_rows_per_cta(self) -> Int

K rows stored in this CTA's SMEM.

Returns:

Int

q_nope_bytes

q_nope_bytes(self) -> Int

Size of the Q nope region in bytes: BM * padded_ov_depth * dtype_size.

Returns:

Int
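
As a worked instance of the byte formula above, here is a small Python sketch. The helper `q_nope_bytes` and the sample values are hypothetical; `dtype_size` is taken to be the element size in bytes, as the formula's wording suggests.

```python
def q_nope_bytes(bm: int, padded_ov_depth: int, dtype_size: int) -> int:
    """Model of BM * padded_ov_depth * dtype_size, in bytes."""
    return bm * padded_ov_depth * dtype_size


# Illustrative values only: a 128-row tile, depth 128, 2-byte elements.
example_bytes = q_nope_bytes(bm=128, padded_ov_depth=128, dtype_size=2)
```

For these inputs the region is 128 * 128 * 2 = 32768 bytes (32 KiB).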

q_rope_bytes

q_rope_bytes(self) -> Int

Q rope region bytes. Uses rope_dtype_size when set, else dtype_size.

Returns:

Int

rope_depth

rope_depth(self) -> Int

Depth of the rope part, calculated as padded_qk_depth - padded_ov_depth. This is 0 for MHA, where qk_depth == ov_depth.

Returns:

Int
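
The subtraction above can be illustrated with a short Python sketch. The function name mirrors the method, but its parameters and the sample depths are hypothetical and chosen only to show the two cases.

```python
def rope_depth(padded_qk_depth: int, padded_ov_depth: int) -> int:
    """Model of padded_qk_depth - padded_ov_depth."""
    return padded_qk_depth - padded_ov_depth


# Illustrative values only.
mla_like = rope_depth(padded_qk_depth=192, padded_ov_depth=128)
mha_like = rope_depth(padded_qk_depth=128, padded_ov_depth=128)
```

Here the first call leaves a 64-wide rope part, while the second models the MHA case where the two depths match and no rope part exists.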

num_rope_buffers

num_rope_buffers(self) -> Int

Number of separate rope smem buffers (fused mode only).

In fused mode, K tiles alternate with V tiles in the pipeline. At most ceildiv(num_kv_stages, 2) K tiles can be in flight at once, so only that many rope buffers are needed. For MHA (rope_depth == 0), no rope buffers are needed.

Returns:

Int
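
The buffer count described above can be sketched in Python as follows. The helper names and parameters are hypothetical. Returning 0 outside fused mode is an assumption read from the phrase "fused mode only"; the source does not state the non-fused value explicitly.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for positive b."""
    return -(-a // b)


def num_rope_buffers(num_kv_stages: int, rope_depth: int, use_fused_kv: bool) -> int:
    """Model of the number of rope shared-memory buffers."""
    if rope_depth == 0:
        # Documented: MHA has no rope part, so no buffers.
        return 0
    if not use_fused_kv:
        # ASSUMPTION: separate rope buffers exist only in fused mode.
        return 0
    # K and V tiles alternate, so at most half the stages
    # (rounded up) hold K tiles at any one time.
    return ceildiv(num_kv_stages, 2)
```

For example, a 5-stage fused pipeline can hold K tiles in stages 0, 2 and 4 at the same time, which gives ceildiv(5, 2) = 3 rope buffers.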

num_k_scale_bufs

num_k_scale_bufs(self) -> Int

Number of staged k_scale smem buffers.

In fused mode, K tiles alternate with V tiles, so at most ceildiv(num_kv_stages, 2) K tiles are in flight at once. In split mode, each KV stage has its own K buffer. Returns 0 when scale_dtype_size == 0 (no per-token scaling).

Returns:

Int
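
A Python sketch of the three cases described above. The function signature is hypothetical. The split-mode result of num_kv_stages is inferred from "each KV stage has its own K buffer" and should be treated as an assumption rather than the confirmed implementation.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for positive b."""
    return -(-a // b)


def num_k_scale_bufs(num_kv_stages: int, scale_dtype_size: int, use_fused_kv: bool) -> int:
    """Model of the number of staged k_scale shared-memory buffers."""
    if scale_dtype_size == 0:
        # Documented: no per-token scaling means no scale buffers.
        return 0
    if use_fused_kv:
        # Documented: K tiles occupy at most every other stage.
        return ceildiv(num_kv_stages, 2)
    # ASSUMPTION: in split mode, one scale buffer per K stage.
    return num_kv_stages
```

Under these assumptions, a 4-stage pipeline with 4-byte scales needs 2 buffers in fused mode and 4 in split mode, and none at all when scaling is disabled.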

supported

supported(self) -> Bool

Returns:

Bool

description

description(self) -> String

Returns:

String

correction_smem_elements

correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

num_active_threads_per_group(self) -> Int

Returns:

Int