Mojo struct
ApplyQKRMSNorm
struct ApplyQKRMSNorm
Applies fused, per-element QK-RMSNorm to two operands, Q and K.
Given the per-row statistics qk_var [M, 2], already reduced across ranks
(col 0 = mean(q^2), col 1 = mean(k^2), float32), and the per-column float32
scales gamma_q [Nq] and gamma_k [Nk], it computes in a single launch:
q_out[m,c] = cast((cast(q[m,c], f32) * rsqrt(qk_var[m,0] + eps)) * gamma_q[c], q.dtype)
and likewise for K, using column 1. The evaluation order, ((x * rs) * gamma)
followed by the cast, where rs = rsqrt(qk_var[m, i] + eps), matches the
unfused graph this op replaces, which preserves bit-accuracy. The op fuses
the QK-RMSNorm apply chain (~7 tiny elementwise/View kernels) into one
launch and is used for cross-head QK-RMSNorm under tensor parallelism.
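The formula above can be sketched as a scalar reference in plain Python. This is only an illustration of the documented arithmetic and grouping, not the kernel itself: the names `f32` and `apply_qk_rmsnorm_ref` are hypothetical, `cast` is a stand-in for the cast back to the q/k dtype, and the float32 rounding emulated here is not guaranteed to match a device `rsqrt` bit for bit.

```python
import math
import struct


def f32(x):
    """Round a Python float to the nearest float32 value (hypothetical helper)."""
    return struct.unpack("f", struct.pack("f", x))[0]


def apply_qk_rmsnorm_ref(q, k, qk_var, gamma_q, gamma_k, eps, cast=f32):
    """Reference for q_out/k_out; `cast` stands in for the cast to q/k dtype."""

    def apply(x, col, gamma):
        out = []
        for m, row in enumerate(x):
            # rs = rsqrt(qk_var[m, col] + eps), computed once per row in float32
            rs = f32(1.0 / math.sqrt(f32(qk_var[m][col] + eps)))
            # Grouping is ((x * rs) * gamma), then the final cast
            out.append(
                [cast(f32(f32(f32(v) * rs) * gamma[c])) for c, v in enumerate(row)]
            )
        return out

    # Q reads column 0 of qk_var, K reads column 1
    return apply(q, 0, gamma_q), apply(k, 1, gamma_k)
```

Calling `apply_qk_rmsnorm_ref(q, k, qk_var, gamma_q, gamma_k, eps)` with nested lists of shape [M, Nq] and [M, Nk] returns the pair `(q_out, k_out)` with the same shapes.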
Outputs (in order): q_out [M, Nq], k_out [M, Nk] (both q/k dtype).
Inputs (in order): q [M, Nq], k [M, Nk] (activation dtype), qk_var [M, 2] (float32), gamma_q [Nq] (float32), gamma_k [Nk] (float32).
Attribute: epsilon (float32 host scalar).
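The page does not say how qk_var is produced, only that it arrives already reduced across ranks. One plausible way to build the [M, 2] layout is sketched below. It assumes each rank holds a column slice of Q and K, that the partial sums of squares are combined with a sum all-reduce, and that the mean is taken over the total width. The function names are hypothetical, and a Python `sum` over a list of ranks stands in for the collective.

```python
def partial_sum_sq(shard):
    """Per-row sum of squares over the columns held by one rank."""
    return [sum(v * v for v in row) for row in shard]


def reduce_qk_var(q_shards, k_shards):
    """Build qk_var [M, 2]: col 0 = mean(q^2), col 1 = mean(k^2).

    Each argument is a list with one [M, local_cols] shard per rank.
    """

    def mean_sq(shards):
        total_cols = sum(len(shard[0]) for shard in shards)
        partials = [partial_sum_sq(shard) for shard in shards]  # one list per rank
        summed = [sum(per_row) for per_row in zip(*partials)]   # stand-in for all-reduce(sum)
        return [s / total_cols for s in summed]

    return [list(pair) for pair in zip(mean_sq(q_shards), mean_sq(k_shards))]
```

The result has one row per token and two columns, so it can be passed as the qk_var input described above.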
Implemented traits
Methods
execute
static def execute[target: StringSlice[StaticConstantOrigin]](
    q_out: ManagedTensorSlice[Output, static_spec=q_out.static_spec],
    k_out: ManagedTensorSlice[Output, static_spec=k_out.static_spec],
    q: ManagedTensorSlice[Input, static_spec=q.static_spec],
    k: ManagedTensorSlice[Input, static_spec=k.static_spec],
    qk_var: ManagedTensorSlice[Input, static_spec=qk_var.static_spec],
    gamma_q: ManagedTensorSlice[Input, static_spec=gamma_q.static_spec],
    gamma_k: ManagedTensorSlice[Input, static_spec=gamma_k.static_spec],
    epsilon: Float32,
    ctx: DeviceContext)