Mojo struct

HKMhaPrefill

struct HKMhaPrefill[config: HKMhaConfig]

8-warp MHA forward kernel parameterized by HKMhaConfig.

Each block runs config.num_warps wave64 warps that share K/V SMEM via cooperative DMA. Warp w owns Q rows [w * q_block_size, (w + 1) * q_block_size) of the block's stripe and carries its own register-resident attention state.
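The row ownership described above can be sketched in plain Python (not Mojo). The values `num_warps = 8` and `q_block_size = 32` are illustrative assumptions, and `warp_q_rows` is a hypothetical helper, not part of the kernel's API.

```python
def warp_q_rows(w: int, q_block_size: int) -> range:
    """Rows of the block's Q stripe owned by warp w (half-open range)."""
    return range(w * q_block_size, (w + 1) * q_block_size)


# Assumed example configuration; real values come from HKMhaConfig.
num_warps = 8
q_block_size = 32

ownership = [warp_q_rows(w, q_block_size) for w in range(num_warps)]
for w, rows in enumerate(ownership):
    print(f"warp {w}: rows [{rows.start}, {rows.stop})")
```

With these assumed values the warps cover 256 consecutive rows with no overlap, which is the BM tile height.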

Parameters​

  • ​config (HKMhaConfig): Shape configuration (HKMhaConfig).

Implemented traits​

AnyType, ImplicitlyDestructible

comptime members​

BM​

comptime BM = (HKMhaPrefill[config].NUM_WARPS * HKMhaPrefill[config].Q_BLOCK_SIZE)

D_FRAG_PER_LANE​

comptime D_FRAG_PER_LANE = ((HKMhaPrefill[config].DEPTH * HKMhaPrefill[config].Q_BLOCK_SIZE) // 64)

DEPTH​

comptime DEPTH = config.depth

k_swizzle​

comptime k_swizzle = Optional(Swizzle(1, 1, 4))

k_swizzle2​

comptime k_swizzle2 = Optional(Swizzle(1, 0, 6))

KTileLoader​

comptime KTileLoader = SubTileLoaderLDS[DType.bfloat16, HKMhaPrefill[config].k_swizzle, HKMhaPrefill[config].k_swizzle2]

KV_BLOCK​

comptime KV_BLOCK = config.kv_block

NUM_HEADS​

comptime NUM_HEADS = config.num_heads

NUM_KV_HEADS​

comptime NUM_KV_HEADS = config.num_kv_heads

NUM_THREADS​

comptime NUM_THREADS = (HKMhaPrefill[config].NUM_WARPS * 64)

NUM_WARPS​

comptime NUM_WARPS = config.num_warps

Q_BLOCK_SIZE​

comptime Q_BLOCK_SIZE = config.q_block_size

RESCALE_THRESHOLD​

comptime RESCALE_THRESHOLD = config.rescale_threshold

v_swizzle​

comptime v_swizzle = Optional(None)

VTileLoader​

comptime VTileLoader = SubTileLoaderLDS_HK_st_8x32[DType.bfloat16, HKMhaPrefill[config].KV_BLOCK, HKMhaPrefill[config].DEPTH, 32, HKMhaPrefill[config].NUM_THREADS]
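As a rough illustration of how the derived members BM, D_FRAG_PER_LANE, and NUM_THREADS follow from the configuration, here is a plain-Python sketch (not Mojo). `ExampleConfig` is a hypothetical stand-in for HKMhaConfig with assumed field values, and reading the divisor 64 in D_FRAG_PER_LANE as the wave64 lane count is an interpretation, not something the reference states.

```python
from dataclasses import dataclass

WAVE_SIZE = 64  # lanes per wave64 warp


@dataclass(frozen=True)
class ExampleConfig:
    """Hypothetical stand-in for the HKMhaConfig fields used below."""
    num_warps: int = 8
    q_block_size: int = 32
    depth: int = 128


def derived_members(cfg: ExampleConfig) -> dict:
    """Mirror the comptime expressions listed in this section."""
    return {
        "BM": cfg.num_warps * cfg.q_block_size,
        "D_FRAG_PER_LANE": (cfg.depth * cfg.q_block_size) // WAVE_SIZE,
        "NUM_THREADS": cfg.num_warps * WAVE_SIZE,
    }


members = derived_members(ExampleConfig())
print(members)
```

For the assumed configuration this gives a 256-row Q tile per block, 64 elements per lane, and 512 threads per block.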

Methods​

load_q​

static load_q[layout: TensorLayout](q_warp_2d: TileTensor[DType.bfloat16, layout, address_space=q_warp_2d.address_space, linear_idx_type=q_warp_2d.linear_idx_type, element_size=q_warp_2d.element_size]) -> TileTensor[DType.bfloat16, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

Loads the warp's Q sub-tile from gmem into a row_l register tile via RegTileLoader.

For DEPTH=128 and Q_BLOCK_SIZE=32, each lane issues 8 buffer_load_bf16x8 loads: one per base tile, with 8 base tiles of MMA_K=16 columns each, distributed col_major[32, 2] so that each lane holds 8 BF16 values per base tile.
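The load count quoted above can be checked with a few lines of arithmetic. This is a plain-Python sketch (not Mojo); the variable names are illustrative, and it assumes 64 lanes per warp (matching the 32 x 2 lane distribution) and 8 BF16 values per buffer_load_bf16x8.

```python
depth = 128          # DEPTH
q_block_size = 32    # rows in the warp's Q sub-tile
mma_k = 16           # columns per base tile
lanes = 32 * 2       # col_major[32, 2] lane distribution
elems_per_load = 8   # one buffer_load_bf16x8 moves 8 BF16 values

base_tiles = depth // mma_k                                  # tiles along DEPTH
elems_per_lane_per_tile = (q_block_size * mma_k) // lanes    # BF16 per lane per tile
loads_per_tile = elems_per_lane_per_tile // elems_per_load   # loads per lane per tile
loads_per_lane = base_tiles * loads_per_tile                 # total loads per lane
elems_per_lane = base_tiles * elems_per_lane_per_tile        # total BF16 per lane

print(base_tiles, elems_per_lane_per_tile, loads_per_lane, elems_per_lane)
```

The total of 64 BF16 values per lane agrees with D_FRAG_PER_LANE = (128 * 32) // 64 for the same shape.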

Returns:

TileTensor[DType.bfloat16, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]

run​

static run[k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, q_dtype: DType, output_dtype: DType, q_layout: TensorLayout, o_layout: TensorLayout](q: TileTensor[q_dtype, q_layout, ImmutAnyOrigin], k: k_t, v: v_t, o: TileTensor[output_dtype, o_layout, MutAnyOrigin], mask_functor: mask_t, scale: Float32, num_keys: Int, start_pos: Int)

Multi-block 8-warp MHA forward (inference-only).

Grid: (NUM_HEADS, ceildiv(seq_len, BM), batch). Each block owns one (batch, head, BM-tile) slice; the 8 warps within split the BM-tile's Q rows.
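A minimal plain-Python sketch (not Mojo) of the grid shape follows. `launch_grid` and `block_q_rows` are hypothetical helpers, the shape values are made up for illustration, and clamping the last tile to seq_len is an assumption about how a partial tile would be bounded, not a statement about the kernel's tail handling.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer division rounded up."""
    return -(-a // b)


def launch_grid(num_heads: int, seq_len: int, bm: int, batch: int) -> tuple:
    """Grid dimensions in the order documented: (heads, BM-tiles, batch)."""
    return (num_heads, ceildiv(seq_len, bm), batch)


def block_q_rows(tile_idx: int, bm: int, seq_len: int) -> tuple:
    """Half-open Q row range of one BM-tile, clamped to seq_len (assumed)."""
    start = tile_idx * bm
    return (start, min(start + bm, seq_len))


# Assumed example shapes.
grid = launch_grid(num_heads=32, seq_len=1000, bm=256, batch=2)
last_tile = block_q_rows(grid[1] - 1, bm=256, seq_len=1000)
print(grid, last_tile)
```

Here a 1000-row sequence needs four BM-tiles, and the last one is only partly filled.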

Expected layouts / shapes:

  • q, o: (batch, seq_len, NUM_HEADS, DEPTH) row-major TileTensor. o's dtype matches config.output_dtype: BF16 for the production dispatcher (which holds a BF16 output buffer), or FP32 if the caller wants the unnormalized accumulator.
  • k, v: any MHAOperand whose block_paged_tile[KV_BLOCK] returns (KV_BLOCK, DEPTH) tiles per (batch, t*KV_BLOCK, kv_head, 0). LayoutTensorMHAOperand for contiguous test / bench buffers; KVCacheMHAOperand for paged production caches (page_size >= KV_BLOCK = 64).
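To make the K/V tile coordinates concrete, the following plain-Python sketch (not Mojo) lists the (batch, t*KV_BLOCK, kv_head, 0) origins for one sequence. `kv_tile_coords` is a hypothetical helper, and visiting ceildiv(num_keys, KV_BLOCK) tiles is an assumption; the kernel may skip tiles that the mask rules out.

```python
def ceildiv(a: int, b: int) -> int:
    """Integer division rounded up."""
    return -(-a // b)


def kv_tile_coords(batch_idx: int, kv_head: int, num_keys: int,
                   kv_block: int = 64) -> list:
    """Origin coordinate of each (KV_BLOCK, DEPTH) tile along the key axis."""
    num_tiles = ceildiv(num_keys, kv_block)  # assumed tile count
    return [(batch_idx, t * kv_block, kv_head, 0) for t in range(num_tiles)]


# Assumed example: 200 keys, batch 0, KV head 3.
coords = kv_tile_coords(batch_idx=0, kv_head=3, num_keys=200)
for c in coords:
    print(c)
```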

batch and seq_len / num_keys may be dynamic; NUM_HEADS, NUM_KV_HEADS, DEPTH must be static. NUM_HEADS must be a multiple of NUM_KV_HEADS (GROUP = NUM_HEADS // NUM_KV_HEADS).
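The grouping constraint can be sketched in plain Python (not Mojo). The reference only defines GROUP; mapping a query head to KV head `head // GROUP` is the common grouped-query convention and is assumed here, and `kv_head_for` is a hypothetical helper.

```python
def kv_head_for(head: int, num_heads: int, num_kv_heads: int) -> int:
    """KV head serving a query head, assuming contiguous grouping."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("NUM_HEADS must be a multiple of NUM_KV_HEADS")
    group = num_heads // num_kv_heads  # GROUP
    return head // group               # assumed mapping, not from the reference


# Assumed example: 32 query heads sharing 8 KV heads (GROUP = 4).
mapping = [kv_head_for(h, 32, 8) for h in range(32)]
print(mapping)
```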

Args:

  • ​q (TileTensor[q_dtype, q_layout, ImmutAnyOrigin]): Q tile tensor.
  • ​k (k_t): K operand (MHAOperand).
  • ​v (v_t): V operand (MHAOperand).
  • ​o (TileTensor[output_dtype, o_layout, MutAnyOrigin]): Output tile tensor (config.output_dtype, same shape as q).
  • mask_functor (mask_t): Per-tile mask predicate (causal, sliding-window, etc.). Evaluated inside the QK→softmax cluster; identity for unmasked attention.
  • ​scale (Float32): Softmax scale (typically 1 / sqrt(DEPTH)).
  • ​num_keys (Int): Runtime length of the K/V sequence.
  • start_pos (Int): Position of the first Q row in the global sequence; non-zero for prefill chunks of a longer generation. Used by the mask functor to compute the causal cutoff.
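The roles of scale and start_pos can be illustrated with a short plain-Python sketch (not Mojo). Both helpers are hypothetical, and the causal rule shown (a key is visible when its index does not exceed start_pos plus the local Q row) is the standard convention, assumed here rather than taken from the mask functor's implementation.

```python
import math


def softmax_scale(depth: int) -> float:
    """Typical softmax scale, 1 / sqrt(DEPTH)."""
    return 1.0 / math.sqrt(depth)


def causal_allows(q_row: int, k_col: int, start_pos: int) -> bool:
    """True if key k_col is visible to local Q row q_row (assumed causal rule)."""
    return k_col <= start_pos + q_row


scale = softmax_scale(128)

# Assumed example: a prefill chunk whose first Q row sits at global position 100.
start_pos = 100
visible_to_row0 = sum(causal_allows(0, k, start_pos) for k in range(256))
print(scale, visible_to_row0)
```

With start_pos = 0 the same rule reduces to the usual lower-triangular mask for a sequence processed in one chunk.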