Skip to main content

Mojo struct

MLAPrefillSparse

struct MLAPrefillSparse[KVLUTType: MHAOperand, output_dtype: DType, config: MLASparseConfig[config.qkv_dtype]]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_dtype

comptime accum_dtype = DType.float32

B_TOPK_PER_CTA

comptime B_TOPK_PER_CTA = 64

FULL_Q_TYPE

comptime FULL_Q_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.FULL_Q_TYPE

k_desc_shape

comptime k_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)

k_gather_box

comptime k_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].k_swizzle_mode]()

k_swizzle_mode

comptime k_swizzle_mode = config.k_swizzle_mode

k_tile_height

comptime k_tile_height = MLAPrefillSparse[KVLUTType, output_dtype, config].B_TOPK_PER_CTA

k_tile_shape

comptime k_tile_shape = Index[Int, Int](64, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)

k_tile_width

comptime k_tile_width = config.qk_depth

NUM_Q_HEADS_PER_CTA

comptime NUM_Q_HEADS_PER_CTA = (config // 2)

NUM_SV_ATOMS

comptime NUM_SV_ATOMS = 2

O_TMEM_ADDR

comptime O_TMEM_ADDR = 0

O_TMEM_ADDR_ATOM2

comptime O_TMEM_ADDR_ATOM2 = (0 + MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM)

O_TYPE

comptime O_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.O_TYPE

P_TMEM_ADDR

comptime P_TMEM_ADDR = 256

q_desc_shape

comptime q_desc_shape = _default_desc_shape[3, MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].q_tile_shape, config.q_swizzle_mode]()

q_smem_depth

comptime q_smem_depth = config.q_smem_depth

q_tile_shape

comptime q_tile_shape = Index[Int, Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, config)

Q_TMEM_ADDR

comptime Q_TMEM_ADDR = 320

q_tmem_depth

comptime q_tmem_depth = config.q_tmem_depth

QKMMAOpType

comptime QKMMAOpType = QKMMAOp[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]

qkv_dtype

comptime qkv_dtype

qkv_dtype_size

comptime qkv_dtype_size = size_of[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype]()

SHARED_QKV_TYPE

comptime SHARED_QKV_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.SHARED_QKV_TYPE

SMemType

comptime SMemType = MLASparseSharedMemory[config]

SVMMAType

comptime SVMMAType = SVMMAType[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]

V_BMN_PER_ATOM

comptime V_BMN_PER_ATOM = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_DEPTH_PER_CTA // 2)

V_DEPTH_PER_CTA

comptime V_DEPTH_PER_CTA = (config // 2)

v_desc_shape

comptime v_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)

v_gather_box

comptime v_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].v_swizzle_mode]()

V_SMEM_COLS_PER_CTA

comptime V_SMEM_COLS_PER_CTA = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM * 2)

v_swizzle_mode

comptime v_swizzle_mode = TensorMapSwizzle.SWIZZLE_128B

v_tile_height

comptime v_tile_height = 64

v_tile_shape

comptime v_tile_shape = Index[Int, Int](64, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)

v_tile_width

comptime v_tile_width = MLAPrefillSparse[KVLUTType, output_dtype, config].V_SMEM_COLS_PER_CTA

Methods

kernel

static kernel[TopKLengthLayout: TensorLayout, IndicesLayout: TensorLayout](q_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 3, MLAPrefillSparse[KVLUTType, output_dtype, config].q_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].q_desc_shape], k_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].k_desc_shape], v_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].v_desc_shape], topk_lengths: TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin], indices: TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin], kv_lut: KVLUTType, scale: Float32, attn_sink_ptr: UnsafePointer[Float32, ImmutAnyOrigin], indices_stride: Int32, output_gmem_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin]) where (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1) and (TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin].flat_rank == 1)

mma

static mma(q_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], k_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], s_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], v_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], so_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], p_free: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, num_k_blocks: UInt32)

cp_q_from_smem_to_tmem

static cp_q_from_smem_to_tmem(smem_desc: MMASmemDescriptorPair, tmem_addr: UInt32)

k_tma_gather4_load

static k_tma_gather4_load[col_range: Tuple[UInt32, UInt32], num_rows: Int](tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], local_indices: InlineArray[SIMD[DType.int32, 4], num_rows], warp_idx: UInt32)

v_tma_gather4_load

static v_tma_gather4_load[local_row_range: Tuple[Int, Int]](tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], warp_idx: UInt32, k: UInt32, cta_id: UInt32, indices_base: UInt32)

kv_valid_producer

static kv_valid_producer(indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], is_k_valid_ptr: UnsafePointer[UInt8, address_space=AddressSpace.SHARED], k_valid_ready_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_valid_free_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], lane_idx: UInt32, indices_base: UInt32, num_kv_rows: Int32, top_k_length: Int32, num_k_blocks: Int)

load_k

static load_k(k_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].k_desc_shape], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, num_kv_rows: Int32, indices_base: UInt32)

load_v

static load_v(v_tma_op: TMATensorTile[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, 2, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_shape, MLAPrefillSparse[KVLUTType, output_dtype, config].v_desc_shape], v_smem_ptr: UnsafePointer[Scalar[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype], address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k: UInt32, warp_idx: UInt32, cta_id: UInt32, indices_base: UInt32)