Mojo struct

MLA_SM100_Decode_Config

struct MLA_SM100_Decode_Config

Fields

  • MMA_M (Int)
  • MMA_PV_N (Int)
  • MMA_QK_N (Int)
  • BM (Int)
  • BN (Int)
  • BK0 (Int)
  • BK1 (Int)
  • q_depth (Int)
  • depth (Int)
  • padded_depth (Int)
  • padded_q_depth (Int)
  • rope_depth (Int)
  • group (Int)
  • num_q_heads (Int)
  • num_kv_heads (Int)
  • tmem_used (Int)
  • num_kv_stages (Int)
  • smem_used (Int)
  • dtype_size (Int)
  • num_threads (Int)
  • swizzle_mode (TensorMapSwizzle)
  • kv_swizzle_mode (TensorMapSwizzle)
  • decoding_warp_split_k (Bool)
  • out_rows (Int)

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = True

cta_group

comptime cta_group = 1

mbar_size

comptime mbar_size = size_of[DType.int64]()

MMA_K

comptime MMA_K = 16

sm100_smem_carveout

comptime sm100_smem_carveout = (B200 - 1024)

sm100_tmem_cols

comptime sm100_tmem_cols = 512

TMEM_CORR_LI

comptime TMEM_CORR_LI = (MLA_SM100_Decode_Config.TMEM_CORR_SCALE + 1)

TMEM_CORR_SCALE

comptime TMEM_CORR_SCALE = (MLA_SM100_Decode_Config.TMEM_S1 + 32)

TMEM_O

comptime TMEM_O = 0

TMEM_S0

comptime TMEM_S0 = (MLA_SM100_Decode_Config.TMEM_O + 256)

TMEM_S1

comptime TMEM_S1 = (MLA_SM100_Decode_Config.TMEM_S0 + 32)
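
The TMEM_* members above are defined as a chain of offsets, each relative to another, so the resolved column indices are not obvious at a glance. The following is a small sketch in plain Python (not Mojo, and not part of the library) that replays the same arithmetic and checks the result against sm100_tmem_cols. The `offsets` dictionary and the variable name `TMEM_COLS` are illustrative only. The gaps between consecutive offsets hint at how many columns each region occupies, but this page does not state region widths, so the sketch reports start columns only.

```python
# Plain-Python mirror of the documented comptime members.
# Only the arithmetic is taken from the page; names like TMEM_COLS
# and `offsets` are hypothetical helpers for this illustration.
TMEM_COLS = 512                      # sm100_tmem_cols

TMEM_O = 0
TMEM_S0 = TMEM_O + 256
TMEM_S1 = TMEM_S0 + 32
TMEM_CORR_SCALE = TMEM_S1 + 32
TMEM_CORR_LI = TMEM_CORR_SCALE + 1

offsets = {
    "TMEM_O": TMEM_O,
    "TMEM_S0": TMEM_S0,
    "TMEM_S1": TMEM_S1,
    "TMEM_CORR_SCALE": TMEM_CORR_SCALE,
    "TMEM_CORR_LI": TMEM_CORR_LI,
}

# Print each start column and confirm it lies inside the 512-column budget.
for name, col in offsets.items():
    assert 0 <= col < TMEM_COLS
    print(f"{name:16s} starts at column {col}")
```

Resolved this way, the start columns are 0, 256, 288, 320 and 321, all of which fall below the 512 columns given by sm100_tmem_cols.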

Methods

__init__

__init__(out self, *, num_q_heads: Int, group: Int, depth: Int, q_depth: Int, dtype_size: Int, swizzle_mode: TensorMapSwizzle, kv_swizzle_mode: TensorMapSwizzle, page_size: Int, decoding_warp_split_k: Bool, fixed_transaction_barriers: Int, num_threads: Int)

supported

supported(self) -> Bool

Returns:

Bool