
Mojo struct

HKMhaConfig

struct HKMhaConfig

Shape configuration for HKMhaPrefill.

Single source of truth for the shape parameters that drive register-tile layouts, SMEM sub-block geometry, grid dimensions, and IGLP scheduling. Lives in hk_mha_mma_op so both MhaMmaOp and HKMhaPrefill can take it as a parameter without a circular import.

Fields

  • q_block_size (Int): Q rows per warp.
  • kv_block (Int): K/V rows per tile (64 when depth is 128).
  • depth (Int): Head depth.
  • num_heads (Int): Number of Q heads.
  • num_kv_heads (Int): Number of K/V heads. Must be 1 (full GQA) or equal to num_heads (MHA); other ratios need a stride-aware DMA loader (TODO).
  • num_warps (Int): Warps per block.
  • rescale_threshold (Float32): Lazy-rescale threshold in log2 units of the running max. Above this, o_reg / norm_vec are rescaled by exp2(max_prev - max_new); below this, the rescale is deferred, and the residual contribution from att_block is bounded by exp2(-rescale_threshold) of the new max.
  • output_dtype (DType): Element dtype of the output tile o. FP32 by default; BF16 for production inference where the dispatcher holds a BF16 output buffer. The cast from the FP32 accumulator happens per-lane inside _store_o_to_gmem.
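
The rescale_threshold behavior can be illustrated with a small numerical sketch. The Python below is not the kernel code: it is a scalar, single-query-row stand-in that assumes the usual online-softmax formulation in base 2, and the names lazy_rescale_attention, reference_attention, max_ref, norm, and out are hypothetical analogues of the running max, norm_vec, and o_reg. It shows that the accumulators are only rescaled when a new block's max exceeds the reference max by more than the threshold, and that the normalized result still matches a direct softmax.

```python
import math


def lazy_rescale_attention(scores, values, block=2, rescale_threshold=8.0):
    """One query row of online-softmax attention with a lazy rescale.

    scores: log2-domain attention logits, one per key.
    values: one value vector (list of floats) per key.
    Returns (normalized output vector, number of rescales performed).
    """
    depth = len(values[0])
    max_ref = -math.inf   # reference max the accumulators are scaled to
    norm = 0.0            # running denominator (norm_vec analogue)
    out = [0.0] * depth   # running unnormalized output (o_reg analogue)
    rescales = 0

    for start in range(0, len(scores), block):
        s_blk = scores[start:start + block]
        v_blk = values[start:start + block]
        max_new = max(s_blk)

        # Rescale only when the new max exceeds the reference by more than
        # the threshold; otherwise keep the old reference (deferred rescale).
        if max_new - max_ref > rescale_threshold:
            scale = 2.0 ** (max_ref - max_new)  # exp2(max_prev - max_new)
            norm *= scale
            out = [o * scale for o in out]
            max_ref = max_new
            rescales += 1

        # In the deferred case each weight is at most 2**rescale_threshold.
        weights = [2.0 ** (s - max_ref) for s in s_blk]
        norm += sum(weights)
        for w, v in zip(weights, v_blk):
            for i in range(depth):
                out[i] += w * v[i]

    return [o / norm for o in out], rescales


def reference_attention(scores, values):
    """Direct base-2 softmax over all keys, used only for comparison."""
    m = max(scores)
    weights = [2.0 ** (s - m) for s in scores]
    total = sum(weights)
    depth = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(depth)]
```

Because the reference max is only a stabilizing shift that cancels in the final division, the threshold changes how often the accumulators are touched, not the normalized output (up to floating-point rounding). A larger threshold trades fewer rescale passes for intermediate weights as large as 2**rescale_threshold.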

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

Methods

__init__

__init__(out self, *, q_block_size: Int, kv_block: Int, depth: Int, num_heads: Int, num_kv_heads: Int, num_warps: Int = 8, rescale_threshold: Float32 = 8, output_dtype: DType = DType.float32)