IMPORTANT: Markdown versions of all pages are available by appending `.md` to any URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

MLAPrefillSparse

struct MLAPrefillSparse[KVLUTType: MHAOperand, output_dtype: DType, config: MLASparseConfig[config.qkv_dtype, config.b_topk_, config.num_mbars_, config.q_smem_depth_, config.q_tmem_depth_]]

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

accum_dtype​

comptime accum_dtype = DType.float32

B_TOPK_PER_CTA​

comptime B_TOPK_PER_CTA = (config.B_TOPK // 2)

FP8_K_SWIZZLE​

comptime FP8_K_SWIZZLE = TensorMapSwizzle.SWIZZLE_NONE

FP8_V_SWIZZLE​

comptime FP8_V_SWIZZLE = TensorMapSwizzle.SWIZZLE_NONE

FULL_Q_TYPE​

comptime FULL_Q_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.FULL_Q_TYPE

k_desc_shape​

comptime k_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)

k_gather_box​

comptime k_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].k_swizzle_mode]()

k_swizzle_mode​

comptime k_swizzle_mode = config.k_swizzle_mode

k_tile_height​

comptime k_tile_height = MLAPrefillSparse[KVLUTType, output_dtype, config].B_TOPK_PER_CTA

k_tile_shape​

comptime k_tile_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_height, MLAPrefillSparse[KVLUTType, output_dtype, config].k_gather_box)

k_tile_width​

comptime k_tile_width = config.qk_depth

k_tma_desc_shape_fp8​

comptime k_tma_desc_shape_fp8 = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_gather_box_fp8)

k_tma_dtype_fp8​

comptime k_tma_dtype_fp8 = DType.int64

k_tma_gather_box_fp8​

comptime k_tma_gather_box_fp8 = _gather4_box_width[DType.int64, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_tile_width_fp8, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_swizzle_fp8]()

k_tma_swizzle_fp8​

comptime k_tma_swizzle_fp8 = MLAPrefillSparse[KVLUTType, output_dtype, config].FP8_K_SWIZZLE

k_tma_tile_shape_fp8​

comptime k_tma_tile_shape_fp8 = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].k_tile_height, MLAPrefillSparse[KVLUTType, output_dtype, config].k_tma_gather_box_fp8)

k_tma_tile_width_fp8​

comptime k_tma_tile_width_fp8 = (config // 8)

NUM_Q_HEADS_PER_CTA​

comptime NUM_Q_HEADS_PER_CTA = (config // 2)

NUM_SV_ATOMS​

comptime NUM_SV_ATOMS = 2

o_desc_shape​

comptime o_desc_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, 64)

o_tile_shape​

comptime o_tile_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, config)

O_TMEM_ADDR​

comptime O_TMEM_ADDR = 0

O_TMEM_ADDR_ATOM2​

comptime O_TMEM_ADDR_ATOM2 = (0 + MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM)

O_TYPE​

comptime O_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.O_TYPE

P_TMEM_ADDR​

comptime P_TMEM_ADDR = 256

q_desc_shape​

comptime q_desc_shape = _default_desc_shape[3, MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].q_tile_shape, config.q_swizzle_mode]()

q_smem_depth​

comptime q_smem_depth = config.q_smem_depth

q_tile_shape​

comptime q_tile_shape = Index[Int, Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].NUM_Q_HEADS_PER_CTA, config)

Q_TMEM_ADDR​

comptime Q_TMEM_ADDR = (512 - (MLAPrefillSparse[KVLUTType, output_dtype, config].q_tmem_depth // 2))

q_tmem_depth​

comptime q_tmem_depth = config.q_tmem_depth

QKMMAOpType​

comptime QKMMAOpType = QKMMAOp[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]

qkv_dtype​

comptime qkv_dtype

qkv_dtype_size​

comptime qkv_dtype_size = size_of[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype]()

SHARED_QKV_TYPE​

comptime SHARED_QKV_TYPE = MLAPrefillSparse[KVLUTType, output_dtype, config].SMemType.SHARED_QKV_TYPE

SMemType​

comptime SMemType = MLASparseSharedMemory[config]

SVMMAType​

comptime SVMMAType = SVMMAType[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, DType.float32, config]

V_BMN_PER_ATOM​

comptime V_BMN_PER_ATOM = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_DEPTH_PER_CTA // 2)

V_DEPTH_PER_CTA​

comptime V_DEPTH_PER_CTA = (config // 2)

v_desc_shape​

comptime v_desc_shape = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)

v_gather_box​

comptime v_gather_box = _gather4_box_width[MLAPrefillSparse[KVLUTType, output_dtype, config].qkv_dtype, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_width, MLAPrefillSparse[KVLUTType, output_dtype, config].v_swizzle_mode]()

V_SMEM_COLS_PER_CTA​

comptime V_SMEM_COLS_PER_CTA = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM * 2)

v_swizzle_mode​

comptime v_swizzle_mode = TensorMapSwizzle.SWIZZLE_128B

v_tile_height​

comptime v_tile_height = (config.B_TOPK // 2)

v_tile_shape​

comptime v_tile_shape = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].v_tile_height, MLAPrefillSparse[KVLUTType, output_dtype, config].v_gather_box)

v_tile_width​

comptime v_tile_width = MLAPrefillSparse[KVLUTType, output_dtype, config].V_SMEM_COLS_PER_CTA

v_tma_desc_shape_fp8​

comptime v_tma_desc_shape_fp8 = Index[Int, Int](1, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_gather_box_fp8)

v_tma_dtype_fp8​

comptime v_tma_dtype_fp8 = DType.int64

v_tma_gather_box_fp8​

comptime v_tma_gather_box_fp8 = _gather4_box_width[DType.int64, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_tile_width_fp8, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_swizzle_fp8]()

v_tma_swizzle_fp8​

comptime v_tma_swizzle_fp8 = MLAPrefillSparse[KVLUTType, output_dtype, config].FP8_V_SWIZZLE

v_tma_tile_height_fp8​

comptime v_tma_tile_height_fp8 = config.B_TOPK

v_tma_tile_shape_fp8​

comptime v_tma_tile_shape_fp8 = Index[Int, Int](MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_tile_height_fp8, MLAPrefillSparse[KVLUTType, output_dtype, config].v_tma_gather_box_fp8)

v_tma_tile_width_fp8​

comptime v_tma_tile_width_fp8 = (MLAPrefillSparse[KVLUTType, output_dtype, config].V_BMN_PER_ATOM // 8)

Methods

kernel​

static def kernel[TopKLengthLayout: TensorLayout, IndicesLayout: TensorLayout](q_tma_op: TMATensorTile[Self.qkv_dtype, 3, Self.q_tile_shape, Self.q_desc_shape], k_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.k_tile_shape, Self.k_desc_shape], v_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.v_tile_shape, Self.v_desc_shape], o_tma_op: TMATensorTile[output_dtype, 2, Self.o_tile_shape, Self.o_desc_shape], topk_lengths: TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin], indices: TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin], kv_lut: KVLUTType, scale: Float32, attn_sink_ptr: Optional[UnsafePointer[Float32, ImmutAnyOrigin]], indices_stride: Int32, output_gmem_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin]) where (TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin].flat_rank == 1) if (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1) else (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1)

mma​

static def mma[fp8_active: Bool = False](q_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], k_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], s_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], v_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], so_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], p_free: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, num_k_blocks: UInt32)

cp_q_from_smem_to_tmem​

static def cp_q_from_smem_to_tmem(smem_desc: MMASmemDescriptorPair, tmem_addr: UInt32)

k_tma_gather4_load​

static def k_tma_gather4_load[col_range: Tuple[UInt32, UInt32], num_rows: Int](tma_op: TMATensorTile[Self.qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[Self.qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], local_indices: InlineArray[SIMD[DType.int32, 4], num_rows], warp_idx: UInt32)

v_tma_gather4_load​

static def v_tma_gather4_load[local_row_range: Tuple[Int, Int]](tma_op: TMATensorTile[Self.qkv_dtype, 2], smem_barrier: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem_tensor: TileTensor[Self.qkv_dtype, address_space=AddressSpace.SHARED, linear_idx_type=smem_tensor.linear_idx_type, element_size=smem_tensor.element_size], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], warp_idx: UInt32, k: UInt32, cta_id: UInt32, indices_base: UInt32)

kv_valid_producer​

static def kv_valid_producer(indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], is_k_valid_ptr: UnsafePointer[UInt8, address_space=AddressSpace.SHARED], k_valid_ready_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_valid_free_ptr: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], lane_idx: UInt32, indices_base: UInt32, num_kv_rows: Int32, top_k_length: Int32, num_k_blocks: Int)

load_k​

static def load_k(k_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.k_tile_shape, Self.k_desc_shape], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], qk_ss_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], qk_ts_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, num_kv_rows: Int32, indices_base: UInt32)

load_v​

static def load_v(v_tma_op: TMATensorTile[Self.qkv_dtype, 2, Self.v_tile_shape, Self.v_desc_shape], v_smem_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], sv_p0_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], sv_p1_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p0_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], v_p1_ready: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k: UInt32, warp_idx: UInt32, cta_id: UInt32, indices_base: UInt32)

load_k_fp8_tma​

static def load_k_fp8_tma(k_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.k_tma_tile_shape_fp8, Self.k_tma_desc_shape_fp8], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], k_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], k_fp8_tma_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, indices_base: UInt32)

load_k_scales_to_smem​

static def load_k_scales_to_smem[scale_block_size: Int](scales_ptr: UnsafePointer[Float32, ImmutAnyOrigin], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], kv_lut: KVLUTType, k_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, indices_base: UInt32, num_kv_rows: Int32)

convert_k_fp8_to_bf16​

static def convert_k_fp8_to_bf16[scale_block_size: Int](k_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], k_smem_bf16_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], k_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED])

load_v_fp8_tma​

static def load_v_fp8_tma(v_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.v_tma_tile_shape_fp8, Self.v_tma_desc_shape_fp8], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], v_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], v_fp8_tma_done: UnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], k: UInt32, cta_id: UInt32, warp_idx: UInt32, indices_base: UInt32)

load_v_scales_to_smem​

static def load_v_scales_to_smem[scale_block_size: Int](scales_ptr: UnsafePointer[Float32, ImmutAnyOrigin], indices: TileTensor[DType.uint32, linear_idx_type=indices.linear_idx_type, element_size=indices.element_size], kv_lut: KVLUTType, v_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED], k: UInt32, indices_base: UInt32, num_kv_rows: Int32)

convert_v_fp8_to_bf16​

static def convert_v_fp8_to_bf16[scale_block_size: Int](v_smem_fp8_ptr: UnsafePointer[Float8_e4m3fn, address_space=AddressSpace.SHARED], v_smem_bf16_ptr: UnsafePointer[Scalar[Self.qkv_dtype], address_space=AddressSpace.SHARED], v_scales_smem_ptr: UnsafePointer[Float32, address_space=AddressSpace.SHARED])

kernel_fp8​

static def kernel_fp8[TopKLengthLayout: TensorLayout, IndicesLayout: TensorLayout, scale_block_size: Int](q_tma_op: TMATensorTile[Self.qkv_dtype, 3, Self.q_tile_shape, Self.q_desc_shape], k_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.k_tma_tile_shape_fp8, Self.k_tma_desc_shape_fp8], v_tma_op_fp8: TMATensorTile[DType.int64, 2, Self.v_tma_tile_shape_fp8, Self.v_tma_desc_shape_fp8], o_tma_op: TMATensorTile[output_dtype, 2, Self.o_tile_shape, Self.o_desc_shape], topk_lengths: TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin], indices: TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin], kv_lut: KVLUTType, scale: Float32, attn_sink_ptr: Optional[UnsafePointer[Float32, ImmutAnyOrigin]], indices_stride: Int32, output_gmem_ptr: UnsafePointer[Scalar[output_dtype], MutAnyOrigin], scales_ptr: UnsafePointer[Float32, ImmutAnyOrigin]) where (TileTensor[DType.uint32, IndicesLayout, MutAnyOrigin].flat_rank == 1) if (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1) else (TileTensor[DType.uint32, TopKLengthLayout, MutAnyOrigin].flat_rank == 1)