Skip to main content

Mojo struct

Depth512SM100Config

struct Depth512SM100Config[qkv_dtype: DType, *, rope_dtype: DType = DType.invalid, scale_dtype: DType = DType.invalid]

Fields

  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • qk_depth (Int)
  • ov_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • TMEM_O (Int)
  • TMEM_O_hi (Int)
  • TMEM_S_even (Int)
  • TMEM_S_odd (Int)
  • tmem_used (Int)
  • num_kv_stages (Int)
  • smem_used (Int)
  • swizzle_mode (TensorMapSwizzle)
  • p_buf_bytes (Int)

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

BM

comptime BM = 64

cta_group

comptime cta_group = 2

mbar_size

comptime mbar_size = size_of[DType.int64]()

MMA_K

comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32

MMA_M

comptime MMA_M = 128

num_pv_stages

comptime num_pv_stages = 2

num_qk_stages

comptime num_qk_stages = 4

num_threads

comptime num_threads = 384

qkv_dtype_size

comptime qkv_dtype_size = size_of[qkv_dtype]()

rope_dtype_size

comptime rope_dtype_size = size_of[rope_dtype]()

scale_dtype_size

comptime scale_dtype_size = size_of[scale_dtype]()

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512

Methods

__init__

__init__(*, num_q_heads: Int, group: Int, qk_depth: Int, ov_depth: Int, swizzle_mode: TensorMapSwizzle, page_size: Int) -> Self

rope_depth

rope_depth(self) -> Int

Returns:

Int

num_qo

num_qo(self) -> Int

Returns:

Int

correction_smem_elements

correction_smem_elements(self) -> Int

Returns:

Int

num_active_warps_per_group

num_active_warps_per_group(self) -> Int

Returns:

Int

num_active_threads_per_group

num_active_threads_per_group(self) -> Int

Returns:

Int

supported

supported(self) -> Bool

Returns:

Bool

description

description(self) -> String

Returns:

String

Was this page helpful?