
Mojo struct

MLASparseSharedMemory

struct MLASparseSharedMemory[config: MLASparseConfig[config.qkv_dtype]]

Fields

  • qkvo_union (UnsafeUnion[InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].FULL_Q_SIZE], InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].SHARED_QKV_SIZE], InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].O_SIZE]])
  • scores (InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].S_SIZE])
  • p (InlineArray[Float32, MLASparseSharedMemory[config].S_SIZE])
  • prologue_q (InlineArray[SharedMemBarrier, 1])
  • prologue_q_cp (InlineArray[SharedMemBarrier, 1])
  • qk_ss_done (InlineArray[SharedMemBarrier, 2])
  • qk_ts_done (InlineArray[SharedMemBarrier, 2])
  • sv_p0_done (InlineArray[SharedMemBarrier, 2])
  • sv_p1_done (InlineArray[SharedMemBarrier, 2])
  • k_p0_ready (InlineArray[SharedMemBarrier, 2])
  • k_p1_ready (InlineArray[SharedMemBarrier, 2])
  • v_p0_ready (InlineArray[SharedMemBarrier, 2])
  • v_p1_ready (InlineArray[SharedMemBarrier, 2])
  • p_free (InlineArray[SharedMemBarrier, 2])
  • so_ready (InlineArray[SharedMemBarrier, 2])
  • k_valid_ready (InlineArray[SharedMemBarrier, 2])
  • k_valid_free (InlineArray[SharedMemBarrier, 2])
  • is_k_valid (InlineArray[UInt8, 32])
  • tmem_addr (InlineArray[UInt32, 1])
  • rowwise_max (InlineArray[Float32, WARPGROUP_SIZE])
  • rowwise_sum (InlineArray[Float32, WARPGROUP_SIZE])
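
As a reading aid for the field list, the following Python sketch tallies the barrier slots declared above. It is plain arithmetic, not Mojo and not part of the library. Two things are assumptions rather than facts from this page: the 8-byte size per `SharedMemBarrier` (typical of a 64-bit hardware barrier object), and the observation that the 32-entry `is_k_valid` array matches `num_mbars * MASK_BYTES_PER_BUF`, which is inferred from the constants listed under comptime members.

```python
# Number of SharedMemBarrier slots per field, copied from the field list.
BARRIER_SLOTS = {
    "prologue_q": 1,
    "prologue_q_cp": 1,
    "qk_ss_done": 2,
    "qk_ts_done": 2,
    "sv_p0_done": 2,
    "sv_p1_done": 2,
    "k_p0_ready": 2,
    "k_p1_ready": 2,
    "v_p0_ready": 2,
    "v_p1_ready": 2,
    "p_free": 2,
    "so_ready": 2,
    "k_valid_ready": 2,
    "k_valid_free": 2,
}

# ASSUMPTION: each barrier occupies 8 bytes; the page does not state this.
ASSUMED_BARRIER_BYTES = 8

# Constants taken from the comptime members on this page.
NUM_MBARS = 2
MASK_BYTES_PER_BUF = 16
IS_K_VALID_LEN = 32  # length of the is_k_valid array in the field list

total_barriers = sum(BARRIER_SLOTS.values())
barrier_bytes = total_barriers * ASSUMED_BARRIER_BYTES

# Inferred relation (not stated by the source): one mask buffer per pipeline stage.
mask_matches_stages = IS_K_VALID_LEN == NUM_MBARS * MASK_BYTES_PER_BUF

print("barrier slots:", total_barriers)
print("barrier bytes (assumed 8 B each):", barrier_bytes)
print("is_k_valid == num_mbars * MASK_BYTES_PER_BUF:", mask_matches_stages)
```

Every two-slot barrier array lines up with `num_mbars = 2`, which suggests double buffering, although the page itself does not say so.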

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

BH

comptime BH = (MLASparseSharedMemory[config].num_q_heads // 2)

FULL_Q_SIZE

comptime FULL_Q_SIZE = (MLASparseSharedMemory[config].BH * MLASparseSharedMemory[config].qk_depth)

FULL_Q_TYPE

comptime FULL_Q_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].FULL_Q_SIZE]

INDICES_PER_LANE

comptime INDICES_PER_LANE = 8

K_SIZE

comptime K_SIZE = (64 * config)

MASK_BYTES_PER_BUF

comptime MASK_BYTES_PER_BUF = 16

NUM_KV_VALID_LANES

comptime NUM_KV_VALID_LANES = MLASparseSharedMemory[config].MASK_BYTES_PER_BUF

num_mbars

comptime num_mbars = 2

num_q_heads

comptime num_q_heads = config.num_q_heads

NUM_SV_ATOMS

comptime NUM_SV_ATOMS = 2

O_SIZE

comptime O_SIZE = (MLASparseSharedMemory[config].BH * config)

O_TYPE

comptime O_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].O_SIZE]

qk_depth

comptime qk_depth = config.qk_depth

qkv_dtype

comptime qkv_dtype

S_SIZE

comptime S_SIZE = (MLASparseSharedMemory[config].BH * 128)

SHARED_Q_SIZE

comptime SHARED_Q_SIZE = (MLASparseSharedMemory[config].BH * 192)

SHARED_QKV_SIZE

comptime SHARED_QKV_SIZE = ((MLASparseSharedMemory[config].SHARED_Q_SIZE + MLASparseSharedMemory[config].K_SIZE) + MLASparseSharedMemory[config].V_SIZE)

SHARED_QKV_TYPE

comptime SHARED_QKV_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].SHARED_QKV_SIZE]

TOPK_PER_CTA

comptime TOPK_PER_CTA = 64

V_BMN_PER_ATOM

comptime V_BMN_PER_ATOM = (MLASparseSharedMemory[config].V_DEPTH_PER_CTA // 2)

V_DEPTH_PER_CTA

comptime V_DEPTH_PER_CTA = (config // 2)

V_SIZE

comptime V_SIZE = (128 * MLASparseSharedMemory[config].V_SMEM_COLS_PER_CTA)

V_SMEM_COLS_PER_CTA

comptime V_SMEM_COLS_PER_CTA = (MLASparseSharedMemory[config].V_BMN_PER_ATOM * 2)
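
The comptime members above form a chain of derived sizes. As a worked illustration, the Python sketch below mirrors that arithmetic. It is not Mojo and not the library's code. The helper `derived_sizes` and all input values are hypothetical. `K_SIZE`, `O_SIZE` and `V_DEPTH_PER_CTA` are passed in as plain numbers because this page shows them only as expressions in `config` without naming the member involved. Taking the largest member as the element count of `qkvo_union` assumes the usual union layout, which the page does not state.

```python
def derived_sizes(num_q_heads, qk_depth, k_size, o_size, v_depth_per_cta):
    """Mirror the size formulas listed on this page (element counts, not bytes)."""
    bh = num_q_heads // 2                       # BH
    full_q_size = bh * qk_depth                 # FULL_Q_SIZE
    s_size = bh * 128                           # S_SIZE
    shared_q_size = bh * 192                    # SHARED_Q_SIZE
    v_bmn_per_atom = v_depth_per_cta // 2       # V_BMN_PER_ATOM
    v_smem_cols_per_cta = v_bmn_per_atom * 2    # V_SMEM_COLS_PER_CTA
    v_size = 128 * v_smem_cols_per_cta          # V_SIZE
    shared_qkv_size = shared_q_size + k_size + v_size  # SHARED_QKV_SIZE
    # ASSUMPTION: a union needs at least as many elements as its largest member.
    union_elems = max(full_q_size, shared_qkv_size, o_size)
    return {
        "BH": bh,
        "FULL_Q_SIZE": full_q_size,
        "S_SIZE": s_size,
        "SHARED_Q_SIZE": shared_q_size,
        "V_BMN_PER_ATOM": v_bmn_per_atom,
        "V_SMEM_COLS_PER_CTA": v_smem_cols_per_cta,
        "V_SIZE": v_size,
        "SHARED_QKV_SIZE": shared_qkv_size,
        "UNION_ELEMS": union_elems,
    }


# Hypothetical inputs chosen only to make the arithmetic concrete.
sizes = derived_sizes(
    num_q_heads=128,
    qk_depth=576,
    k_size=4096,
    o_size=16384,
    v_depth_per_cta=256,
)

for name, value in sizes.items():
    print(f"{name} = {value}")
```

Because `V_SMEM_COLS_PER_CTA` is computed as `(V_DEPTH_PER_CTA // 2) * 2`, an odd `V_DEPTH_PER_CTA` is rounded down to the nearest even number before `V_SIZE` is derived.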