Mojo struct
Depth512SM100Config
struct Depth512SM100Config[qkv_dtype: DType, *, rope_dtype: DType = DType.invalid, scale_dtype: DType = DType.invalid]
Fields
- MMA_M (Int)
- BM (Int)
- num_qk_stages (Int)
- split_o (Bool)
- v_cols_per_cta (Int)
- BN (Int)
- BK0 (Int)
- BK1 (Int)
- qk_depth (Int)
- ov_depth (Int)
- group (Int)
- num_q_heads (Int)
- num_kv_heads (Int)
- TMEM_O (Int)
- TMEM_O_hi (Int)
- TMEM_S_even (Int)
- TMEM_S_odd (Int)
- tmem_used (Int)
- fuse_gqa (Bool)
- num_kv_stages (Int)
- smem_used (Int)
- swizzle_mode (TensorMapSwizzle)
- p_buf_bytes (Int)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
cta_group
comptime cta_group = Int(2)
mbar_size
comptime mbar_size = size_of[DType.int64]()
MMA_K
comptime MMA_K = Int(16) if qkv_dtype.is_half_float() else Int(32)
num_pv_stages
comptime num_pv_stages = Int(2)
num_threads
comptime num_threads = (Int(3) * _resolve_warpgroup_size())
qkv_dtype_size
comptime qkv_dtype_size = size_of[qkv_dtype]()
rope_dtype_size
comptime rope_dtype_size = size_of[rope_dtype]()
scale_dtype_size
comptime scale_dtype_size = size_of[scale_dtype]()
sm100_smem_carveout
comptime sm100_smem_carveout = (GPUInfo.from_family(AcceleratorArchitectureFamily(Int(32), Int(2048), Int(233472), Int(65536), Int(1024)), StringSlice("B200"), Vendor(Int8(2)), StringSlice("cuda"), StringSlice("blackwell"), SIMD(10), StringSlice("sm_100a"), Int(148)) - Int(1024))
sm100_tmem_cols
comptime sm100_tmem_cols = Int(512)
Methods
__init__
def __init__(*, num_q_heads: Int, group: Int, qk_depth: Int, ov_depth: Int, swizzle_mode: TensorMapSwizzle, page_size: Int) -> Self
BM_eff
def BM_eff(self) -> Int
Number of distinct sequence positions per CTA tile.
When fuse_gqa is set, each CTA tile covers BM // group sequence positions × group heads = BM physical rows.
Returns: Int
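The row arithmetic described for BM_eff can be illustrated with a small sketch. This is plain Python, not the Mojo implementation: the function name `bm_eff`, its parameters, and the sample values are hypothetical, and the non-fused branch (returning BM unchanged) is an assumption, since the source only describes the fuse_gqa case.

```python
def bm_eff(bm: int, group: int, fuse_gqa: bool) -> int:
    """Hypothetical stand-in: distinct sequence positions per CTA tile."""
    if fuse_gqa:
        # With fused GQA, `group` query heads share each sequence position,
        # so BM physical rows cover BM // group positions.
        return bm // group
    # Assumption (not stated in the source): without fusion, every
    # physical row is its own sequence position.
    return bm


# Illustrative values only; they are not taken from the source.
bm, group = 128, 8
positions = bm_eff(bm, group, fuse_gqa=True)  # 16 sequence positions
physical_rows = positions * group             # 16 positions x 8 heads = 128 rows
```

When BM is a multiple of group, the product of positions and heads recovers BM exactly, which matches the "= BM physical rows" statement above.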
rope_depth
num_qo
correction_smem_elements
num_active_warps_per_group
num_active_threads_per_group
supported
description