To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MLASparseSharedMemory

struct MLASparseSharedMemory[config: MLASparseConfig[config.qkv_dtype, config.b_topk_, config.num_mbars_, config.q_smem_depth_, config.q_tmem_depth_]]

Fields

  • qkvo_union (UnsafeUnion[InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].FULL_Q_SIZE], InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].SHARED_QKV_SIZE], InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].O_SIZE]])
  • scores (InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].S_SIZE])
  • p (InlineArray[Float32, MLASparseSharedMemory[config].S_SIZE])
  • prologue_q (InlineArray[SharedMemBarrier, 1])
  • prologue_q_cp (InlineArray[SharedMemBarrier, 1])
  • qk_ss_done (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • qk_ts_done (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • sv_p0_done (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • sv_p1_done (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • k_p0_ready (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • k_p1_ready (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • v_p0_ready (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • v_p1_ready (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • p_free (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • so_ready (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • k_valid_ready (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • k_valid_free (InlineArray[SharedMemBarrier, MLASparseSharedMemory[config].num_mbars])
  • is_k_valid (InlineArray[UInt8, (MLASparseSharedMemory[config].num_mbars * MLASparseSharedMemory[config].MASK_BYTES_PER_BUF)])
  • tmem_addr (InlineArray[UInt32, 1])
  • rowwise_max (InlineArray[Float32, WARPGROUP_SIZE])
  • rowwise_sum (InlineArray[Float32, WARPGROUP_SIZE])

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

BH​

comptime BH = (MLASparseSharedMemory[config].num_q_heads // 2)

FULL_Q_SIZE​

comptime FULL_Q_SIZE = (MLASparseSharedMemory[config].BH * MLASparseSharedMemory[config].qk_depth)

FULL_Q_TYPE​

comptime FULL_Q_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].FULL_Q_SIZE]

INDICES_PER_LANE​

comptime INDICES_PER_LANE = 8

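K_SIZE

comptime K_SIZE = (MLASparseSharedMemory[config].TOPK_PER_CTA * config)

Note: the rendered expression is incomplete. The second operand appears as bare `config`, so the field being read from `config` was dropped by the documentation generator. Consult the struct's source for the actual field.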

MASK_BYTES_PER_BUF​

comptime MASK_BYTES_PER_BUF = (config.B_TOPK // 8)

NUM_KV_VALID_LANES​

comptime NUM_KV_VALID_LANES = MLASparseSharedMemory[config].MASK_BYTES_PER_BUF

num_mbars​

comptime num_mbars = config.num_mbars

num_q_heads​

comptime num_q_heads = config.num_q_heads

NUM_SV_ATOMS​

comptime NUM_SV_ATOMS = 2

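O_SIZE

comptime O_SIZE = (MLASparseSharedMemory[config].BH * config)

Note: the rendered expression is incomplete. The second operand appears as bare `config`, so the field being read from `config` was dropped by the documentation generator. Consult the struct's source for the actual field.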

O_TYPE​

comptime O_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].O_SIZE]

qk_depth​

comptime qk_depth = config.qk_depth

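qkv_dtype

comptime qkv_dtype

Note: no value is rendered for this member. By analogy with `num_mbars`, `num_q_heads`, and `qk_depth`, it presumably forwards `config.qkv_dtype`; confirm against the struct's source.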

S_SIZE​

comptime S_SIZE = (MLASparseSharedMemory[config].BH * config.B_TOPK)

SHARED_Q_SIZE​

comptime SHARED_Q_SIZE = (MLASparseSharedMemory[config].BH * config.q_smem_depth)

SHARED_QKV_SIZE​

comptime SHARED_QKV_SIZE = ((MLASparseSharedMemory[config].SHARED_Q_SIZE + MLASparseSharedMemory[config].K_SIZE) + MLASparseSharedMemory[config].V_SIZE)

SHARED_QKV_TYPE​

comptime SHARED_QKV_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].SHARED_QKV_SIZE]

TOPK_PER_CTA​

comptime TOPK_PER_CTA = (config.B_TOPK // 2)

V_BMN_PER_ATOM​

comptime V_BMN_PER_ATOM = (MLASparseSharedMemory[config].V_DEPTH_PER_CTA // 2)

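V_DEPTH_PER_CTA

comptime V_DEPTH_PER_CTA = (config // 2)

Note: the rendered expression is incomplete. The dividend appears as bare `config`, so the field being read from `config` was dropped by the documentation generator. Consult the struct's source for the actual field.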

V_SIZE​

comptime V_SIZE = (config.B_TOPK * MLASparseSharedMemory[config].V_SMEM_COLS_PER_CTA)

V_SMEM_COLS_PER_CTA​

comptime V_SMEM_COLS_PER_CTA = (MLASparseSharedMemory[config].V_BMN_PER_ATOM * 2)