For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
MLAPrefillSparse
struct MLAPrefillSparse[KVLUTType: MHAOperand, output_dtype: DType, config: MLASparseConfig[config.qkv_dtype, config.b_topk_, config.num_mbars_, config.q_smem_depth_, config.q_tmem_depth_]]
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
accum_dtype
comptime accum_dtype = DType.float32
B_TOPK_PER_CTA
comptime B_TOPK_PER_CTA = (config.B_TOPK // 2)
FP8_K_SWIZZLE
comptime FP8_K_SWIZZLE = TensorMapSwizzle.SWIZZLE_NONE
FP8_V_SWIZZLE
comptime FP8_V_SWIZZLE = TensorMapSwizzle.SWIZZLE_NONE
FULL_Q_TYPE
comptime FULL_Q_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.FULL_Q_TYPE
k_desc_shape
comptime k_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)
k_gather_box
comptime k_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].k_swizzle_mode]()
k_swizzle_mode
comptime k_swizzle_mode = config.k_swizzle_mode
k_tile_height
comptime k_tile_height = MLAPrefillSparse[KVLUTType, output_dtype, config].B_TOPK_PER_CTA
k_tile_shape
comptime k_tile_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_height, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)
k_tile_width
comptime k_tile_width = config.qk_depth
k_tma_desc_shape_fp8
comptime k_tma_desc_shape_fp8 = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_gather_box_fp8)
k_tma_dtype_fp8
comptime k_tma_dtype_fp8 = DType.int64
k_tma_gather_box_fp8
comptime k_tma_gather_box_fp8 = _gather4_box_width[DType.int64, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_tile_width_fp8, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_swizzle_fp8]()
k_tma_swizzle_fp8
comptime k_tma_swizzle_fp8 = MLAPrefillSparse[KVLUTType, output_dtype, config].FP8_K_SWIZZLE
k_tma_tile_shape_fp8
comptime k_tma_tile_shape_fp8 = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_height, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_gather_box_fp8)
k_tma_tile_width_fp8
comptime k_tma_tile_width_fp8 = (config // 8)
NUM_Q_HEADS_PER_CTA
comptime NUM_Q_HEADS_PER_CTA = (config // 2)
NUM_SV_ATOMS
comptime NUM_SV_ATOMS = 2
o_desc_shape
comptime o_desc_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, 64)
o_tile_shape
comptime o_tile_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, config)
O_TMEM_ADDR
comptime O_TMEM_ADDR = 0
O_TMEM_ADDR_ATOM2
comptime O_TMEM_ADDR_ATOM2 = (0 + MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM)
O_TYPE
comptime O_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.O_TYPE
P_TMEM_ADDR
comptime P_TMEM_ADDR = 256
q_desc_shape
comptime q_desc_shape = _default_desc_shape[3, MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].q_tile_shape, config.q_swizzle_mode]()
q_smem_depth
comptime q_smem_depth = config.q_smem_depth
q_tile_shape
comptime q_tile_shape = Index[Int, Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, config)
Q_TMEM_ADDR
comptime Q_TMEM_ADDR = (512 - (MLAPrefillSparse[KVLUTType, output_dtype, config].q_tmem_depth // 2))
q_tmem_depth
comptime q_tmem_depth = config.q_tmem_depth
QKMMAOpType
comptime QKMMAOpType = QKMMAOp[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]
qkv_dtype
comptime qkv_dtype
qkv_dtype_size
comptime qkv_dtype_size = size_of[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype]()
SHARED_QKV_TYPE
comptime SHARED_QKV_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.SHARED_QKV_TYPE
SMemType
comptime SMemType = MLASparseSharedMemory[config]
SVMMAType
comptime SVMMAType = SVMMAType[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]
V_BMN_PER_ATOM
comptime V_BMN_PER_ATOM = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_DEPTH_PER_CTA // 2)
V_DEPTH_PER_CTA
comptime V_DEPTH_PER_CTA = (config // 2)
v_desc_shape
comptime v_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)
v_gather_box
comptime v_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].v_swizzle_mode]()
V_SMEM_COLS_PER_CTA
comptime V_SMEM_COLS_PER_CTA = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM * 2)
v_swizzle_mode
comptime v_swizzle_mode = TensorMapSwizzle.SWIZZLE_128B
v_tile_height
comptime v_tile_height = (config.B_TOPK // 2)
v_tile_shape
comptime v_tile_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_height, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)
v_tile_width
comptime v_tile_width = MLAPrefillSparse[KVLUTType, output_dtype, config].V_SMEM_COLS_PER_CTA
v_tma_desc_shape_fp8
comptime v_tma_desc_shape_fp8 = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_gather_box_fp8)
v_tma_dtype_fp8
comptime v_tma_dtype_fp8 = DType.int64
v_tma_gather_box_fp8
comptime v_tma_gather_box_fp8 = _gather4_box_width[DType.int64, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_tile_width_fp8, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_swizzle_fp8]()
v_tma_swizzle_fp8
comptime v_tma_swizzle_fp8 = MLAPrefillSparse[KVLUTType, output_dtype, config].FP8_V_SWIZZLE
v_tma_tile_height_fp8
comptime v_tma_tile_height_fp8 = config.B_TOPK
v_tma_tile_shape_fp8
comptime v_tma_tile_shape_fp8 = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_tile_height_fp8, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_gather_box_fp8)
v_tma_tile_width_fp8
comptime v_tma_tile_width_fp8 = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM // 8)
Methods
kernel
static def kernel[TopKLengthLayout: TensorLayout, IndicesLayout: TensorLayout](q_tma_op: TMATensorTile[Self.qkv_dtype, 3, Self.q_tile_shape, Self.q_desc_shape], k_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.k_tile_shape, Self.k_desc_shape], v_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.v_tile_shape, Self.v_desc_shape], o_tma_op: TMATensorTile[output_dtype, 2, Self.o_tile_shape, Self.o_desc_shape], topk_lengths: TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin], indices: TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin], kv_lut: KVLUTType, scale: Float32, attn_sink_ptr: Optional[UnsafePointer[Float32, ImmutAnyOrigin]], indices_stride: Int32, output_gmem_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin]) where (TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin].flat_rank == 1) if (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1) else (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1)
mma
static def mma[fp8_active: Bool = False](q_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], k_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], s_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], v_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], so_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], p_free: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, num_k_blocks: UInt32)
cp_q_from_smem_to_tmem
static def cp_q_from_smem_to_tmem(smem_desc: MMASmemDescriptorPair, tmem_addr: UInt32)
k_tma_gather4_load
static def k_tma_gather4_load[col_range: Tuple[UInt32, UInt32], num_rows: Int](tma_op: TMATensorTile[Self.qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[Self.qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], local_indices: InlineArray[SIMD[DType.int32, 4], num_rows], warp_idx: UInt32)
v_tma_gather4_load
static def v_tma_gather4_load[local_row_range: Tuple[Int, Int]](tma_op: TMATensorTile[Self.qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[Self.qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], warp_idx: UInt32, k: UInt32, cta_id: UInt32, indices_base: UInt32)
kv_valid_producer
static def kv_valid_producer(indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], is_k_valid_ptr: UnsafePointer[UInt8, address_space=AddressSpace.SHARED], k_valid_ready_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_valid_free_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], lane_idx: UInt32, indices_base: UInt32, num_kv_rows: Int32, top_k_length: Int32, num_k_blocks: Int)
load_k
static def load_k(k_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.k_tile_shape, Self.k_desc_shape], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, num_kv_rows: Int32, indices_base: UInt32)
load_v
static def load_v(v_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.v_tile_shape, Self.v_desc_shape], v_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k: UInt32, warp_idx: UInt32, cta_id: UInt32, indices_base: UInt32)
load_k_fp8_tma
static def load_k_fp8_tma(k_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.k_tma_tile_shape_fp8, Self.k_tma_desc_shape_fp8], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], k_fp8_tma_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, indices_base: UInt32)
load_k_scales_to_smem
static def load_k_scales_to_smem[scale_block_size: Int](scales_ptr: UnsafePointer[Float32, ImmutAnyOrigin], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], kv_lut: KVLUTType, k_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, indices_base: UInt32, num_kv_rows: Int32)
convert_k_fp8_to_bf16
static def convert_k_fp8_to_bf16[scale_block_size: Int](k_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], k_smem_bf16_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], k_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED])
load_v_fp8_tma
static def load_v_fp8_tma(v_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.v_tma_tile_shape_fp8, Self.v_tma_desc_shape_fp8], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], v_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], v_fp8_tma_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, indices_base: UInt32)
load_v_scales_to_smem
static def load_v_scales_to_smem[scale_block_size: Int](scales_ptr: UnsafePointer[Float32, ImmutAnyOrigin], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], kv_lut: KVLUTType, v_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED], k: UInt32, indices_base: UInt32, num_kv_rows: Int32)
convert_v_fp8_to_bf16
static def convert_v_fp8_to_bf16[scale_block_size: Int](v_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], v_smem_bf16_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], v_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED])
kernel_fp8
static def kernel_fp8[TopKLengthLayout: TensorLayout, IndicesLayout: TensorLayout, scale_block_size: Int](q_tma_op: TMATensorTile[Self.qkv_dtype, 3, Self.q_tile_shape, Self.q_desc_shape], k_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.k_tma_tile_shape_fp8, Self.k_tma_desc_shape_fp8], v_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.v_tma_tile_shape_fp8, Self.v_tma_desc_shape_fp8], o_tma_op: TMATensorTile[output_dtype, 2, Self.o_tile_shape, Self.o_desc_shape], topk_lengths: TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin], indices: TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin], kv_lut: KVLUTType, scale: Float32, attn_sink_ptr: Optional[UnsafePointer[Float32, ImmutAnyOrigin]], indices_stride: Int32, output_gmem_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin], scales_ptr: UnsafePointer[Float32, ImmutAnyOrigin]) where (TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin].flat_rank == 1) if (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1) else (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!