
Mojo struct

Depth512SM100Config

struct Depth512SM100Config[qkv_dtype: DType, *, rope_dtype: DType = DType.invalid, scale_dtype: DType = DType.invalid]

Fields

  • MMA_M (Int)
  • BM (Int)
  • num_qk_stages (Int)
  • split_o (Bool)
  • v_cols_per_cta (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • qk_depth (Int)
  • ov_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_O (Int)
  • TMEM_O_hi (Int)
  • TMEM_S_even (Int)
  • TMEM_S_odd (Int)
  • tmem_used (Int)
  • fuse_gqa (Bool)
  • num_kv_stages (Int)
  • smem_used (Int)
  • swizzle_mode (TensorMapSwizzle)
  • p_buf_bytes (Int)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

cta_group

comptime cta_group = 2

mbar_size

comptime mbar_size = size_of[DType.int64]()

MMA_K

comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32
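For 16-bit and 8-bit element types, both branches of this rule span the same 32 bytes along K: 16 elements × 2 bytes, or 32 elements × 1 byte. The Python sketch below mirrors the selection. The dtype names and the `DTYPE_SIZE`/`HALF_FLOATS` tables are hypothetical stand-ins for Mojo's `DType`, and it assumes `is_half_float()` is true exactly for `float16` and `bfloat16`.

```python
# Hypothetical stand-ins for Mojo's DType: element sizes in bytes.
DTYPE_SIZE = {"float16": 2, "bfloat16": 2, "float8_e4m3fn": 1, "float8_e5m2": 1}
# Assumption: is_half_float() is true exactly for the 16-bit float types.
HALF_FLOATS = {"float16", "bfloat16"}


def mma_k(qkv_dtype: str) -> int:
    """Mirror of: comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32."""
    return 16 if qkv_dtype in HALF_FLOATS else 32


def mma_k_bytes(qkv_dtype: str) -> int:
    """Bytes spanned along K by one MMA for this dtype."""
    return mma_k(qkv_dtype) * DTYPE_SIZE[qkv_dtype]


for name in DTYPE_SIZE:
    print(name, mma_k(name), mma_k_bytes(name))
```

Under these assumptions, every listed dtype reports 32 bytes per MMA along K.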

num_pv_stages

comptime num_pv_stages = 2

num_threads

comptime num_threads = (3 * WARPGROUP_SIZE)
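A warpgroup is assumed here to be four 32-thread warps (128 threads). Under that assumption, `num_threads` works out to 384 threads, or 12 warps, per CTA. A quick check of the arithmetic:

```python
WARP_SIZE = 32                    # threads per NVIDIA warp
WARPGROUP_SIZE = 4 * WARP_SIZE    # assumption: a warpgroup is four warps (128 threads)

num_threads = 3 * WARPGROUP_SIZE  # mirrors: comptime num_threads = (3 * WARPGROUP_SIZE)
num_warps = num_threads // WARP_SIZE

print(num_threads, num_warps)     # 384 12
```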

qkv_dtype_size

comptime qkv_dtype_size = size_of[qkv_dtype]()

rope_dtype_size

comptime rope_dtype_size = size_of[rope_dtype]()

scale_dtype_size

comptime scale_dtype_size = size_of[scale_dtype]()

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512

Methods

__init__

def __init__(*, num_q_heads: Int, group: Int, qk_depth: Int, ov_depth: Int, swizzle_mode: TensorMapSwizzle, page_size: Int) -> Self

BM_eff

def BM_eff(self) -> Int

Number of distinct sequence positions per CTA tile.

When fuse_gqa is True, each CTA tile covers BM // group sequence positions × group heads = BM physical rows.

Returns:

Int
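The relationship described for `BM_eff` can be sketched in Python. The function name and the example values (`BM = 128`, `group = 8`) are illustrative only. The non-fused branch, where each physical row is treated as its own sequence position, is an assumption and is not stated on this page.

```python
def bm_eff(bm: int, group: int, fuse_gqa: bool) -> int:
    """Hypothetical mirror of Depth512SM100Config.BM_eff().

    With GQA fusion, the BM physical rows of a tile are laid out as
    (BM // group) sequence positions x group query heads.
    """
    if fuse_gqa:
        return bm // group
    # Assumption: without fusion, each of the BM rows is a distinct position.
    return bm


BM, GROUP = 128, 8  # illustrative values, not taken from the config
positions = bm_eff(BM, GROUP, fuse_gqa=True)
print(positions, positions * GROUP)  # 16 128
```

When `BM` is a multiple of `group`, `positions * group` recovers exactly `BM` physical rows, which matches the description above.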

rope_depth

def rope_depth(self) -> Int

Returns:

Int

num_qo

def num_qo(self) -> Int

Returns:

Int

correction_smem_elements

def correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

def num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

def num_active_threads_per_group(self) -> Int

Returns:

Int

supported

def supported(self) -> Bool

Returns:

Bool

description

def description(self) -> String

Returns:

String