Mojo struct
HKMhaPrefill
struct HKMhaPrefill[config: HKMhaConfig]
8-warp MHA forward kernel parameterized by HKMhaConfig.
Each block runs config.num_warps wave64 warps that share K/V
SMEM via cooperative DMA. Warp w owns Q rows [w * q_block_size, (w + 1) * q_block_size) of the block's stripe and carries its own
register-resident attention state.
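The row ownership rule above can be modeled in plain Python as a quick sanity check. This sketch is not the kernel code. It uses 8 warps and q_block_size = 32 as illustrative values, and it shows that the per-warp ranges are disjoint and together cover the block's BM-row stripe.

```python
# Plain-Python model of the per-warp Q-row split inside one block's stripe.
# The values used below (8 warps, q_block_size=32) are illustrative only.


def warp_q_rows(warp_id: int, q_block_size: int) -> range:
    """Rows [warp_id * q_block_size, (warp_id + 1) * q_block_size) of the stripe."""
    return range(warp_id * q_block_size, (warp_id + 1) * q_block_size)


def stripe_partition(num_warps: int, q_block_size: int) -> list:
    """One disjoint row range per warp; together they cover all BM rows."""
    return [warp_q_rows(w, q_block_size) for w in range(num_warps)]


if __name__ == "__main__":
    for w, rows in enumerate(stripe_partition(num_warps=8, q_block_size=32)):
        print(f"warp {w}: rows [{rows.start}, {rows.stop})")
```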
Parameters
- config (HKMhaConfig): Shape configuration: warp count, Q block size, head depth, KV block size, head counts, rescale threshold and output dtype.
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
BM
comptime BM = (HKMhaPrefill[config].NUM_WARPS * HKMhaPrefill[config].Q_BLOCK_SIZE)
Q rows processed per block (the BM-tile).
D_FRAG_PER_LANE
comptime D_FRAG_PER_LANE = ((HKMhaPrefill[config].DEPTH * HKMhaPrefill[config].Q_BLOCK_SIZE) // 64)
Per-lane share of a warp's Q_BLOCK_SIZE × DEPTH tile, spread across the 64 lanes of a wave64 warp.
DEPTH
comptime DEPTH = config.depth
k_swizzle
comptime k_swizzle = Optional(Swizzle(1, 1, 4))
k_swizzle2
comptime k_swizzle2 = Optional(Swizzle(1, 0, 6))
KTileLoader
comptime KTileLoader = SubTileLoaderLDS[DType.bfloat16, HKMhaPrefill[config].k_swizzle, HKMhaPrefill[config].k_swizzle2]
KV_BLOCK
comptime KV_BLOCK = config.kv_block
NUM_HEADS
comptime NUM_HEADS = config.num_heads
NUM_KV_HEADS
comptime NUM_KV_HEADS = config.num_kv_heads
NUM_THREADS
comptime NUM_THREADS = (HKMhaPrefill[config].NUM_WARPS * 64)
Threads per block (64 lanes per wave64 warp).
NUM_WARPS
comptime NUM_WARPS = config.num_warps
Q_BLOCK_SIZE
comptime Q_BLOCK_SIZE = config.q_block_size
RESCALE_THRESHOLD
comptime RESCALE_THRESHOLD = config.rescale_threshold
v_swizzle
comptime v_swizzle = Optional(None)
VTileLoader
comptime VTileLoader = SubTileLoaderLDS_HK_st_8x32[DType.bfloat16, HKMhaPrefill[config].KV_BLOCK, HKMhaPrefill[config].DEPTH, 32, HKMhaPrefill[config].NUM_THREADS]
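The derived members (BM, NUM_THREADS, D_FRAG_PER_LANE) are simple products of config fields. The following Python sketch reproduces those formulas. ExampleConfig is a hypothetical stand-in for HKMhaConfig that carries only the three fields the formulas need, with defaults taken from the d=128, Q_BLOCK_SIZE=32, 8-warp case discussed on this page.

```python
from dataclasses import dataclass

WAVE_SIZE = 64  # lanes per wave64 warp


@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical stand-in for HKMhaConfig holding only the fields used here."""
    num_warps: int = 8
    q_block_size: int = 32
    depth: int = 128


def derived_constants(cfg: ExampleConfig) -> dict:
    """Mirror the BM, NUM_THREADS and D_FRAG_PER_LANE formulas."""
    return {
        "BM": cfg.num_warps * cfg.q_block_size,
        "NUM_THREADS": cfg.num_warps * WAVE_SIZE,
        "D_FRAG_PER_LANE": (cfg.depth * cfg.q_block_size) // WAVE_SIZE,
    }


if __name__ == "__main__":
    print(derived_constants(ExampleConfig()))
```

With these values D_FRAG_PER_LANE is 64 BF16 per lane, which matches the 8 loads × 8 BF16 per lane that load_q describes for the same shape.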
Methods
load_q
static load_q[layout: TensorLayout](q_warp_2d: TileTensor[DType.bfloat16, layout, address_space=q_warp_2d.address_space, linear_idx_type=q_warp_2d.linear_idx_type, element_size=q_warp_2d.element_size]) -> TileTensor[DType.bfloat16, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
Loads the warp's Q sub-tile from gmem into a row_l register tile via RegTileLoader.
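For DEPTH = 128 and Q_BLOCK_SIZE = 32, each lane issues 8 buffer_load_bf16x8 loads. The 128 columns form 8 base tiles of MMA_K = 16 columns. Each base tile is distributed across the warp as col_major[32, 2], with 8 BF16 values per lane per base tile (32 × 16 = 512 elements = 64 lanes × 8).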
Returns:
TileTensor[DType.bfloat16, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
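The load_q distribution can be sketched as a lane-to-element map. This Python model assumes that col_major[32, 2] means lane = row + 32 * col_group, where col_group picks the left or right 8 columns of each 16-column base tile. That reading is an assumption; the actual RegTileLoader mapping is not spelled out here. The model checks that the 64 lanes together cover the full 32 × 128 Q tile exactly once.

```python
MMA_K = 16          # columns per base tile
ELEMS_PER_LOAD = 8  # BF16 values per buffer_load_bf16x8
Q_ROWS = 32         # Q_BLOCK_SIZE in the load_q example
WAVE_SIZE = 64      # lanes per wave64 warp


def lane_q_loads(lane: int, depth: int = 128) -> list:
    """(row, columns) for each 8-wide load issued by `lane`.

    Assumes a col_major[32, 2] lane layout: lane = row + 32 * col_group,
    where col_group selects the left or right 8 columns of each base tile.
    """
    row = lane % Q_ROWS
    col_group = lane // Q_ROWS
    loads = []
    for tile in range(depth // MMA_K):  # 8 base tiles when depth == 128
        start = tile * MMA_K + col_group * ELEMS_PER_LOAD
        loads.append((row, list(range(start, start + ELEMS_PER_LOAD))))
    return loads


if __name__ == "__main__":
    for lane in (0, 31, 32, 63):
        row, cols = lane_q_loads(lane)[0]
        print(f"lane {lane:2d}: row {row:2d}, first load cols {cols[0]}..{cols[-1]}")
```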
run
static run[k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, q_dtype: DType, output_dtype: DType, q_layout: TensorLayout, o_layout: TensorLayout](q: TileTensor[q_dtype, q_layout, ImmutAnyOrigin], k: k_t, v: v_t, o: TileTensor[output_dtype, o_layout, MutAnyOrigin], mask_functor: mask_t, scale: Float32, num_keys: Int, start_pos: Int)
Multi-block 8-warp MHA forward (inference-only).
Grid: (NUM_HEADS, ceildiv(seq_len, BM), batch). Each block
owns one (batch, head, BM-tile) slice; the 8 (NUM_WARPS) warps in
the block split the BM-tile's Q rows, Q_BLOCK_SIZE rows each.
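The launch geometry can be sketched in Python. This model assumes that the block index is ordered like the grid tuple, (head, m_tile, batch). It also clips the last, possibly partial, BM-tile to seq_len; that clipping is for illustration only, since how the kernel handles tail rows is not described here.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division for non-negative a and positive b."""
    return -(-a // b)


def grid_dims(num_heads: int, seq_len: int, bm: int, batch: int) -> tuple:
    """Grid shape as documented: (NUM_HEADS, ceildiv(seq_len, BM), batch)."""
    return (num_heads, ceildiv(seq_len, bm), batch)


def block_slice(block_idx: tuple, bm: int, seq_len: int) -> tuple:
    """(batch, head, Q rows) owned by one block.

    Assumes block_idx is ordered like the grid: (head, m_tile, batch).
    The last M tile is clipped to seq_len here purely for illustration.
    """
    head, m_tile, batch = block_idx
    q_start = m_tile * bm
    return batch, head, range(q_start, min(q_start + bm, seq_len))
```

For example, 32 heads, seq_len = 1000, BM = 256 and batch = 2 give a (32, 4, 2) grid, and the fourth M tile covers only rows 768 to 999.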
Expected layouts / shapes:
- q, o: (batch, seq_len, NUM_HEADS, DEPTH) row-major TileTensor. o's dtype matches config.output_dtype: BF16 for the production dispatcher (which holds a BF16 output buffer), or FP32 if the caller wants the unnormalized accumulator.
- k, v: any MHAOperand whose block_paged_tile[KV_BLOCK] returns (KV_BLOCK, DEPTH) tiles per (batch, t*KV_BLOCK, kv_head, 0). Use LayoutTensorMHAOperand for contiguous test / bench buffers and KVCacheMHAOperand for paged production caches (page_size >= KV_BLOCK = 64).
batch and seq_len / num_keys may be dynamic;
NUM_HEADS, NUM_KV_HEADS, DEPTH must be static.
NUM_HEADS must be a multiple of NUM_KV_HEADS
(GROUP = NUM_HEADS // NUM_KV_HEADS).
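The GROUP constraint and the (batch, t*KV_BLOCK, kv_head, 0) tile coordinates can be sketched as follows. The head mapping kv_head = head // GROUP, where consecutive query heads share one KV head, is the usual grouped-query convention and is assumed here rather than stated above. The sketch visits all ceildiv(num_keys, KV_BLOCK) tiles, although the kernel may skip tiles that the mask excludes.

```python
def kv_head_for(head: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a query head to its KV head, assuming consecutive heads share a KV head."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("NUM_HEADS must be a multiple of NUM_KV_HEADS")
    group = num_heads // num_kv_heads  # GROUP
    return head // group


def kv_tile_coords(batch: int, head: int, num_keys: int,
                   num_heads: int, num_kv_heads: int, kv_block: int = 64) -> list:
    """(batch, t * KV_BLOCK, kv_head, 0) coordinates for every K/V tile.

    Visits all ceildiv(num_keys, kv_block) tiles; a real kernel may skip
    tiles that the mask rules out entirely.
    """
    kv_head = kv_head_for(head, num_heads, num_kv_heads)
    num_tiles = -(-num_keys // kv_block)
    return [(batch, t * kv_block, kv_head, 0) for t in range(num_tiles)]
```

With 32 query heads and 8 KV heads (GROUP = 4), query head 5 reads KV head 1, and 130 keys span three 64-key tiles.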
Args:
- q (TileTensor[q_dtype, q_layout, ImmutAnyOrigin]): Q tile tensor.
- k (k_t): K operand (MHAOperand).
- v (v_t): V operand (MHAOperand).
- o (TileTensor[output_dtype, o_layout, MutAnyOrigin]): Output tile tensor (config.output_dtype, same shape as q).
- mask_functor (mask_t): Per-tile mask predicate (causal, sliding-window, etc.). Evaluated inside the QK-to-softmax cluster; identity for unmasked attention.
- scale (Float32): Softmax scale (typically 1 / sqrt(DEPTH)).
- num_keys (Int): Runtime length of the K/V sequence.
- start_pos (Int): Position of the first Q row in the global sequence; non-zero for prefill chunks of a longer generation. Used by the mask functor to compute the causal cutoff.
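The scale and start_pos arguments can be illustrated with a small Python model. It assumes a plain causal rule: local Q row i sits at global position start_pos + i and attends to every key at a position at or before that, bounded by num_keys. The actual mask functor may define the cutoff differently. The functions visible_keys and softmax_scale are hypothetical helpers for this sketch.

```python
import math


def softmax_scale(depth: int) -> float:
    """Typical softmax scale, 1 / sqrt(DEPTH)."""
    return 1.0 / math.sqrt(depth)


def visible_keys(q_row: int, start_pos: int, num_keys: int) -> range:
    """Keys a local Q row may attend to under an assumed causal rule.

    Local row i is taken to sit at global position start_pos + i and to see
    every key at a position <= start_pos + i (bounded by num_keys).
    """
    return range(0, min(start_pos + q_row + 1, num_keys))


if __name__ == "__main__":
    # A 4-row prefill chunk that starts at global position 8, with 12 keys cached.
    for i in range(4):
        keys = visible_keys(i, start_pos=8, num_keys=12)
        print(f"row {i} (pos {8 + i}): keys [0, {keys.stop})")
```

This is why start_pos matters for chunked prefill: with start_pos = 0, the first row of a later chunk would wrongly be treated as position 0 and see only one key.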