
Mojo struct

Depth512SM100Config

struct Depth512SM100Config[qkv_dtype: DType, *, rope_dtype: DType = DType.invalid, scale_dtype: DType = DType.invalid]

Fields

  • MMA_M (Int)
  • BM (Int)
  • num_qk_stages (Int)
  • split_o (Bool)
  • v_cols_per_cta (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • qk_depth (Int)
  • ov_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_O (Int)
  • TMEM_O_hi (Int)
  • TMEM_S_even (Int)
  • TMEM_S_odd (Int)
  • tmem_used (Int)
  • fuse_gqa (Bool)
  • num_kv_stages (Int)
  • smem_used (Int)
  • swizzle_mode (TensorMapSwizzle)
  • p_buf_bytes (Int)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

cta_group

comptime cta_group = 2

mbar_size

comptime mbar_size = size_of[DType.int64]()

MMA_K

comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32
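One way to read this conditional: `is_half_float()` is true for 16-bit floats such as `float16` and `bfloat16`, so both branches make each MMA step consume 32 bytes along K (16 elements × 2 bytes, or 32 elements × 1 byte for an 8-bit type). The Python sketch below only checks that arithmetic. The `DTYPES` table is a hypothetical stand-in for Mojo's `DType`, and treating FP8 as the non-half-float case is an assumption, not something this page states.

```python
# Hypothetical stand-in for Mojo's DType: name -> (size in bytes, is_half_float()).
DTYPES = {
    "float16": (2, True),
    "bfloat16": (2, True),
    "float8_e4m3fn": (1, False),  # assumed 8-bit QKV type
    "float8_e5m2": (1, False),    # assumed 8-bit QKV type
}


def mma_k(dtype: str) -> int:
    """Mirrors `MMA_K = 16 if qkv_dtype.is_half_float() else 32`."""
    _, is_half = DTYPES[dtype]
    return 16 if is_half else 32


def k_bytes_per_mma(dtype: str) -> int:
    """Bytes of K consumed per MMA step: MMA_K * qkv_dtype_size."""
    size, _ = DTYPES[dtype]
    return mma_k(dtype) * size


for name in DTYPES:
    print(f"{name}: MMA_K={mma_k(name)}, bytes along K={k_bytes_per_mma(name)}")
```

Every entry prints 32 bytes along K, which is why the element count doubles when the element size halves.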

num_pv_stages

comptime num_pv_stages = 2

num_threads

comptime num_threads = 384
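For orientation, assuming `num_threads` counts the threads of a single CTA, 384 threads is 12 warps of 32 threads, or three 128-thread warpgroups. The sketch below does only that arithmetic; how the kernel assigns roles to those warps is not described on this page.

```python
NUM_THREADS = 384        # comptime num_threads (assumed to be per CTA)
WARP_SIZE = 32           # threads per warp on NVIDIA GPUs
WARPS_PER_WARPGROUP = 4  # a warpgroup is 4 consecutive warps (128 threads)

num_warps = NUM_THREADS // WARP_SIZE
num_warpgroups = num_warps // WARPS_PER_WARPGROUP

print(num_warps, num_warpgroups)  # 12 3
```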

qkv_dtype_size

comptime qkv_dtype_size = size_of[qkv_dtype]()

rope_dtype_size

comptime rope_dtype_size = size_of[rope_dtype]()

scale_dtype_size

comptime scale_dtype_size = size_of[scale_dtype]()

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512
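On SM100, Tensor Memory is laid out as 128 lanes by 512 columns of 32-bit cells per SM, so the 512-column limit corresponds to 256 KiB. A short Python calculation of that capacity follows; the lane count and cell width are hardware facts assumed here, not values defined by this struct.

```python
TMEM_LANES = 128       # Tensor Memory lanes (rows) per SM on SM100
SM100_TMEM_COLS = 512  # comptime sm100_tmem_cols
BYTES_PER_CELL = 4     # each lane/column cell holds one 32-bit value

tmem_bytes = TMEM_LANES * SM100_TMEM_COLS * BYTES_PER_CELL
tmem_kib = tmem_bytes // 1024

print(tmem_bytes, tmem_kib)  # 262144 256
```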

Methods

__init__

__init__(*, num_q_heads: Int, group: Int, qk_depth: Int, ov_depth: Int, swizzle_mode: TensorMapSwizzle, page_size: Int) -> Self

BM_eff

BM_eff(self) -> Int

Number of distinct sequence positions per CTA tile.

When fuse_gqa is set, each CTA tile covers BM // group sequence positions × group heads, for a total of BM physical rows.

Returns:

Int
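A Python sketch of the BM_eff arithmetic described above. The BM and group values are illustrative only, and the non-fused branch (one sequence position per row) is an assumption, since this page documents only the fuse_gqa case.

```python
def bm_eff(bm: int, group: int, fuse_gqa: bool) -> int:
    """Distinct sequence positions covered by one CTA tile of BM rows."""
    if fuse_gqa:
        # Rows are (position, head) pairs: BM // group positions x group heads = BM rows,
        # which requires BM to be a multiple of group.
        assert bm % group == 0, "BM must be a multiple of group"
        return bm // group
    # Assumption: without GQA fusion, every row is a distinct sequence position.
    return bm


BM, GROUP = 128, 8  # illustrative values, not taken from the config
positions = bm_eff(BM, GROUP, fuse_gqa=True)
print(positions, positions * GROUP)  # 16 128
```

With fusion, a tile trades sequence coverage for head coverage: the same 128 rows span 16 positions across 8 query heads instead of 128 positions of one head.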

rope_depth

rope_depth(self) -> Int

Returns:

Int

num_qo

num_qo(self) -> Int

Returns:

Int

correction_smem_elements

correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

num_active_threads_per_group(self) -> Int

Returns:

Int

supported

supported(self) -> Bool

Returns:

Bool

description

description(self) -> String

Returns:

String