Mojo struct
Depth512SM100Config
struct Depth512SM100Config[qkv_dtype: DType, *, rope_dtype: DType = DType.invalid, scale_dtype: DType = DType.invalid]
Fields
- BN (Int)
- BK0 (Int)
- BK1 (Int)
- qk_depth (Int)
- ov_depth (Int)
- group (Int)
- num_q_heads (Int)
- num_kv_heads (Int)
- TMEM_O (Int)
- TMEM_O_hi (Int)
- TMEM_S_even (Int)
- TMEM_S_odd (Int)
- tmem_used (Int)
- num_kv_stages (Int)
- smem_used (Int)
- swizzle_mode (TensorMapSwizzle)
- p_buf_bytes (Int)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
BM
comptime BM = 64
cta_group
comptime cta_group = 2
mbar_size
comptime mbar_size = size_of[DType.int64]()
MMA_K
comptime MMA_K = 16 if qkv_dtype.is_half_float() else 32
MMA_M
comptime MMA_M = 128
num_pv_stages
comptime num_pv_stages = 2
num_qk_stages
comptime num_qk_stages = 4
num_threads
comptime num_threads = 384
qkv_dtype_size
comptime qkv_dtype_size = size_of[qkv_dtype]()
rope_dtype_size
comptime rope_dtype_size = size_of[rope_dtype]()
scale_dtype_size
comptime scale_dtype_size = size_of[scale_dtype]()
sm100_smem_carveout
comptime sm100_smem_carveout = (B200 - 1024)
sm100_tmem_cols
comptime sm100_tmem_cols = 512
Methods
__init__
__init__(*, num_q_heads: Int, group: Int, qk_depth: Int, ov_depth: Int, swizzle_mode: TensorMapSwizzle, page_size: Int) -> Self
rope_depth
num_qo
correction_smem_elements
num_active_warps_per_group
num_active_threads_per_group
supported
description
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!