Mojo struct
MLASparseSharedMemory
struct MLASparseSharedMemory[config: MLASparseConfig[config.qkv_dtype]]
Fields
- qkvo_union (UnsafeUnion[InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].FULL_Q_SIZE], InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].SHARED_QKV_SIZE], InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].O_SIZE]])
- scores (InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].S_SIZE])
- p (InlineArray[Float32, MLASparseSharedMemory[config].S_SIZE])
- prologue_q (InlineArray[SharedMemBarrier, 1])
- prologue_q_cp (InlineArray[SharedMemBarrier, 1])
- qk_ss_done (InlineArray[SharedMemBarrier, 2])
- qk_ts_done (InlineArray[SharedMemBarrier, 2])
- sv_p0_done (InlineArray[SharedMemBarrier, 2])
- sv_p1_done (InlineArray[SharedMemBarrier, 2])
- k_p0_ready (InlineArray[SharedMemBarrier, 2])
- k_p1_ready (InlineArray[SharedMemBarrier, 2])
- v_p0_ready (InlineArray[SharedMemBarrier, 2])
- v_p1_ready (InlineArray[SharedMemBarrier, 2])
- p_free (InlineArray[SharedMemBarrier, 2])
- so_ready (InlineArray[SharedMemBarrier, 2])
- k_valid_ready (InlineArray[SharedMemBarrier, 2])
- k_valid_free (InlineArray[SharedMemBarrier, 2])
- is_k_valid (InlineArray[UInt8, 32])
- tmem_addr (InlineArray[UInt32, 1])
- rowwise_max (InlineArray[Float32, WARPGROUP_SIZE])
- rowwise_sum (InlineArray[Float32, WARPGROUP_SIZE])
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
BH
comptime BH = (MLASparseSharedMemory[config].num_q_heads // 2)
FULL_Q_SIZE
comptime FULL_Q_SIZE = (MLASparseSharedMemory[config].BH * MLASparseSharedMemory[config].qk_depth)
FULL_Q_TYPE
comptime FULL_Q_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].FULL_Q_SIZE]
INDICES_PER_LANE
comptime INDICES_PER_LANE = 8
K_SIZE
comptime K_SIZE = (64 * config)
MASK_BYTES_PER_BUF
comptime MASK_BYTES_PER_BUF = 16
NUM_KV_VALID_LANES
comptime NUM_KV_VALID_LANES = MLASparseSharedMemory[config].MASK_BYTES_PER_BUF
num_mbars
comptime num_mbars = 2
num_q_heads
comptime num_q_heads = config.num_q_heads
NUM_SV_ATOMS
comptime NUM_SV_ATOMS = 2
O_SIZE
comptime O_SIZE = (MLASparseSharedMemory[config].BH * config)
O_TYPE
comptime O_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].O_SIZE]
qk_depth
comptime qk_depth = config.qk_depth
qkv_dtype
comptime qkv_dtype
S_SIZE
comptime S_SIZE = (MLASparseSharedMemory[config].BH * 128)
SHARED_Q_SIZE
comptime SHARED_Q_SIZE = (MLASparseSharedMemory[config].BH * 192)
SHARED_QKV_SIZE
comptime SHARED_QKV_SIZE = ((MLASparseSharedMemory[config].SHARED_Q_SIZE + MLASparseSharedMemory[config].K_SIZE) + MLASparseSharedMemory[config].V_SIZE)
SHARED_QKV_TYPE
comptime SHARED_QKV_TYPE = InlineArray[Scalar[MLASparseSharedMemory[config].qkv_dtype], MLASparseSharedMemory[config].SHARED_QKV_SIZE]
TOPK_PER_CTA
comptime TOPK_PER_CTA = 64
V_BMN_PER_ATOM
comptime V_BMN_PER_ATOM = (MLASparseSharedMemory[config].V_DEPTH_PER_CTA // 2)
V_DEPTH_PER_CTA
comptime V_DEPTH_PER_CTA = (config // 2)
V_SIZE
comptime V_SIZE = (128 * MLASparseSharedMemory[config].V_SMEM_COLS_PER_CTA)
V_SMEM_COLS_PER_CTA
comptime V_SMEM_COLS_PER_CTA = (MLASparseSharedMemory[config].V_BMN_PER_ATOM * 2)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!