Mojo struct
HKMhaConfig
struct HKMhaConfig
Shape configuration for HKMhaPrefill.
Single source of truth for the shape parameters that drive
register-tile layouts, SMEM sub-block geometry, grid dimensions,
and IGLP scheduling. Lives in hk_mha_mma_op so both MhaMmaOp
and HKMhaPrefill can take it as a parameter without a circular
import.
Fields
- q_block_size (Int): Q rows per warp.
- kv_block (Int): K/V rows per tile (64 at d=128).
- depth (Int): Head depth.
- num_heads (Int): Number of Q heads.
- num_kv_heads (Int): Number of K/V heads. Must be 1 (full GQA) or equal to num_heads (MHA); other ratios need a stride-aware DMA loader (TODO).
- num_warps (Int): Warps per block.
- rescale_threshold (Float32): Lazy-rescale threshold, in log2 units, on the growth of the running max. When max_new - max_prev exceeds this threshold, o_reg and norm_vec are rescaled by exp2(max_prev - max_new). Otherwise the rescale is deferred, and the residual contribution from att_block is bounded by exp2(-rescale_threshold) of the new max.
- output_dtype (DType): Element dtype of the output tile o. FP32 by default; BF16 for production inference, where the dispatcher holds a BF16 output buffer. The FP32 accumulator is cast to this dtype per lane inside _store_o_to_gmem.
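To make the rescale_threshold trade-off concrete, here is a minimal Python model (not the Mojo kernel) of online softmax with a thresholded, lazy rescale. It assumes scalar values in place of o_reg tiles, scores already expressed in log2 units, and a finite threshold; lazy_rescale_attention and reference_attention are hypothetical names used only for illustration. Because the accumulator and the normalizer always share the same reference max, deferring the rescale changes the magnitude of the intermediate weights but not the final result.

```python
import math


def lazy_rescale_attention(score_blocks, value_blocks, rescale_threshold=8.0):
    """Scalar model of online softmax with a thresholded (lazy) rescale.

    Scores are in log2 units, so weights are 2**(s - m). The threshold must
    be finite. Returns (weighted_average, number_of_rescales).
    """
    m = -math.inf  # reference max currently baked into o and l
    l = 0.0        # running normalizer (norm_vec analogue)
    o = 0.0        # running weighted sum (o_reg analogue)
    rescales = 0
    for scores, values in zip(score_blocks, value_blocks):
        m_new = max(m, max(scores))
        if m_new - m > rescale_threshold:
            # Eager path: move the reference max and rescale both accumulators.
            factor = 2.0 ** (m - m_new)  # 0.0 on the first block (m = -inf)
            o *= factor
            l *= factor
            m = m_new
            rescales += 1
        # Deferred path keeps the stale reference m. Every score is <= m_new and
        # m_new <= m + rescale_threshold, so each weight is at most 2**threshold.
        for s, v in zip(scores, values):
            p = 2.0 ** (s - m)
            l += p
            o += p * v
    return o / l, rescales


def reference_attention(score_blocks, value_blocks):
    """Direct (non-streaming) softmax-weighted average for comparison."""
    scores = [s for blk in score_blocks for s in blk]
    values = [v for blk in value_blocks for v in blk]
    m = max(scores)
    w = [2.0 ** (s - m) for s in scores]
    return sum(wi * vi for wi, vi in zip(w, values)) / sum(w)


if __name__ == "__main__":
    scores = [[0.0, 1.0], [2.0, 3.0], [12.0, 0.5]]
    values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    print(lazy_rescale_attention(scores, values, 8.0)[1])  # 2 rescales
    print(lazy_rescale_attention(scores, values, 0.0)[1])  # 3 rescales
    print(reference_attention(scores, values))
```

With a threshold of 0.0 the model rescales on every increase of the max; larger thresholds skip more rescales at the cost of intermediate weights as large as 2**rescale_threshold (256 at the default of 8).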
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
Methods
__init__
__init__(out self, *, q_block_size: Int, kv_block: Int, depth: Int, num_heads: Int, num_kv_heads: Int, num_warps: Int = 8, rescale_threshold: Float32 = 8, output_dtype: DType = DType.float32)