Mojo struct
MLAPrefillSparse
struct MLAPrefillSparse[KVLUTType: MHAOperand, output_dtype: DType, config: MLASparseConfig[config.qkv_dtype]]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_dtype
comptime accum_dtype = DType.float32
B_TOPK_PER_CTA
comptime B_TOPK_PER_CTA = 64
FULL_Q_TYPE
comptime FULL_Q_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.FULL_Q_TYPE
k_desc_shape
comptime k_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)
k_gather_box
comptime k_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].k_swizzle_mode]()
k_swizzle_mode
comptime k_swizzle_mode = config.k_swizzle_mode
k_tile_height
comptime k_tile_height = MLAPrefillSparse[KVLUTType, output_dtype, config].B_TOPK_PER_CTA
k_tile_shape
comptime k_tile_shape = Index[Int, Int](64, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)
k_tile_width
comptime k_tile_width = config.qk_depth
NUM_Q_HEADS_PER_CTA
comptime NUM_Q_HEADS_PER_CTA = (config // 2)
NUM_SV_ATOMS
comptime NUM_SV_ATOMS = 2
O_TMEM_ADDR
comptime O_TMEM_ADDR = 0
O_TMEM_ADDR_ATOM2
comptime O_TMEM_ADDR_ATOM2 = (0 + MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM)
O_TYPE
comptime O_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.O_TYPE
P_TMEM_ADDR
comptime P_TMEM_ADDR = 256
q_desc_shape
comptime q_desc_shape = _default_desc_shape[3, MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].q_tile_shape, config.q_swizzle_mode]()
q_smem_depth
comptime q_smem_depth = config.q_smem_depth
q_tile_shape
comptime q_tile_shape = Index[Int, Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, config)
Q_TMEM_ADDR
comptime Q_TMEM_ADDR = 320
q_tmem_depth
comptime q_tmem_depth = config.q_tmem_depth
QKMMAOpType
comptime QKMMAOpType = QKMMAOp[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]
qkv_dtype
comptime qkv_dtype
qkv_dtype_size
comptime qkv_dtype_size = size_of[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype]()
SHARED_QKV_TYPE
comptime SHARED_QKV_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.SHARED_QKV_TYPE
SMemType
comptime SMemType = MLASparseSharedMemory[config]
SVMMAType
comptime SVMMAType = SVMMAType[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]
V_BMN_PER_ATOM
comptime V_BMN_PER_ATOM = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_DEPTH_PER_CTA // 2)
V_DEPTH_PER_CTA
comptime V_DEPTH_PER_CTA = (config // 2)
v_desc_shape
comptime v_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)
v_gather_box
comptime v_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].v_swizzle_mode]()
V_SMEM_COLS_PER_CTA
comptime V_SMEM_COLS_PER_CTA = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM * 2)
v_swizzle_mode
comptime v_swizzle_mode = TensorMapSwizzle.SWIZZLE_128B
v_tile_height
comptime v_tile_height = 64
v_tile_shape
comptime v_tile_shape = Index[Int, Int](64, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)
v_tile_width
comptime v_tile_width = MLAPrefillSparse[KVLUTType, output_dtype, config].V_SMEM_COLS_PER_CTA
Methods
kernel
static kernel[TopKLengthLayout: TensorLayout, IndicesLayout: TensorLayout](q_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 3, MLAPrefillSparse[KVLUTType, output_dtype, config].q_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].q_desc_shape], k_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].k_desc_shape], v_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].v_desc_shape], topk_lengths: TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin], indices: TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin], kv_lut: KVLUTType, scale: Float32, attn_sink_ptr: UnsafePointer[Float32, ImmutAnyOrigin], indices_stride: Int32, output_gmem_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin]) where (TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin].flat_rank == 1) if (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1) else (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1)
mma
static mma(q_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], k_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], s_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], v_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], so_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], p_free: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, num_k_blocks: UInt32)
cp_q_from_smem_to_tmem
static cp_q_from_smem_to_tmem(smem_desc: MMASmemDescriptorPair, tmem_addr: UInt32)
k_tma_gather4_load
static k_tma_gather4_load[col_range: Tuple[UInt32, UInt32], num_rows: Int](tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], local_indices: InlineArray[SIMD[DType.int32, 4], num_rows], warp_idx: UInt32)
v_tma_gather4_load
static v_tma_gather4_load[local_row_range: Tuple[Int, Int]](tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], warp_idx: UInt32, k: UInt32, cta_id: UInt32, indices_base: UInt32)
kv_valid_producer
static kv_valid_producer(indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], is_k_valid_ptr: UnsafePointer[UInt8, address_space=AddressSpace.SHARED], k_valid_ready_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_valid_free_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], lane_idx: UInt32, indices_base: UInt32, num_kv_rows: Int32, top_k_length: Int32, num_k_blocks: Int)
load_k
static load_k(k_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].k_desc_shape], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, num_kv_rows: Int32, indices_base: UInt32)
load_v
static load_v(v_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].v_desc_shape], v_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k: UInt32, warp_idx: UInt32, cta_id: UInt32, indices_base: UInt32)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!