IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

ApplyQKRMSNorm

struct ApplyQKRMSNorm

Fused elementwise QK-RMSNorm application for two operands, Q and K.

The op takes per-row statistics qk_var [M, 2] that have already been reduced across ranks (float32; column 0 holds mean(q^2), column 1 holds mean(k^2)) and per-column float32 scales gamma_q [Nq] and gamma_k [Nk]. In a single launch, it computes:

q_out[m,c] = cast((cast(q[m,c], f32) * rsqrt(qk_var[m,0] + eps)) * gamma_q[c], q.dtype)

K is handled the same way, using column 1 of qk_var and gamma_k. The evaluation order, ((x * rs) * gamma) followed by the cast, matches the unfused graph that this op replaces, which keeps the fused output bit-accurate with respect to it. The op fuses the QK-RMSNorm apply chain (~7 tiny elementwise/View kernels) into one launch. It is used for cross-head QK-RMSNorm under tensor parallelism.
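The statistics reach this op already reduced across ranks. The sketch below is one hedged illustration of what that reduction means. It assumes column-sharded tensor parallelism: each rank holds a disjoint slice of Q's (or K's) columns and contributes a per-row partial sum of squares. Summing those partials stands in for an all-reduce, and dividing by the full row width gives mean(x^2). The helper names `row_mean_sq_sharded` and `build_qk_var` are hypothetical and not part of the Mojo API. This page does not describe how MAX actually produces qk_var.

```python
def row_mean_sq_sharded(shards):
    """Hypothetical: per-row mean(x^2) for a tensor split by columns across ranks.

    shards[r] is rank r's slice: a list of M rows, each a list of floats.
    Assumption: every rank holds the same M rows and a disjoint set of columns.
    """
    m_rows = len(shards[0])
    total_cols = sum(len(shard[0]) for shard in shards)  # full row width
    means = []
    for m in range(m_rows):
        # Each rank's local partial sum of squares for row m.
        partials = [sum(x * x for x in shard[m]) for shard in shards]
        # Summing the partials stands in for a cross-rank all-reduce.
        means.append(sum(partials) / total_cols)
    return means


def build_qk_var(q_shards, k_shards):
    """Pack the statistics as [M, 2]: column 0 = mean(q^2), column 1 = mean(k^2)."""
    q_stats = row_mean_sq_sharded(q_shards)
    k_stats = row_mean_sq_sharded(k_shards)
    return [[qs, ks] for qs, ks in zip(q_stats, k_stats)]


# Two ranks, one row: Q columns [1, 2] live on rank 0 and [3] on rank 1.
qk_var = build_qk_var([[[1.0, 2.0]], [[3.0]]], [[[2.0]], [[2.0]]])
print(qk_var)  # mean(q^2) = 14/3, mean(k^2) = 4.0
```

Splitting the row across ranks does not change the statistic, because a sum of squares decomposes into per-shard partial sums. Only the final division needs the global column count.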

Outputs (in order): q_out [M, Nq], k_out [M, Nk] (both in the q/k dtype).

Inputs (in order): q [M, Nq], k [M, Nk] (activation dtype), qk_var [M, 2] (float32), gamma_q [Nq] (float32), gamma_k [Nk] (float32).

Attribute: epsilon (float32 host scalar).
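A pure-Python reference can make the per-element math and the documented argument order concrete. This is a minimal sketch, not the Mojo kernel, and it rests on several assumptions:

- `apply_qk_rmsnorm_ref` and `to_f32` are hypothetical names.
- Tensors are nested lists.
- float32 rounding is emulated with the standard `struct` module.
- rsqrt is approximated as 1/sqrt rounded to float32, so a device rsqrt may differ in the last bit.
- The final cast to a narrower activation dtype such as bfloat16 is not modeled, so outputs stay float32.

```python
import math
import struct


def to_f32(x):
    """Round a Python float (a double) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def apply_qk_rmsnorm_ref(q, k, qk_var, gamma_q, gamma_k, epsilon):
    """Hypothetical reference; returns (q_out, k_out).

    Arguments follow the documented input order. Shapes as nested lists:
    q [M][Nq], k [M][Nk], qk_var [M][2], gamma_q [Nq], gamma_k [Nk].
    """
    m_rows = len(q)
    # Shape checks mirroring the documented contract.
    if len(k) != m_rows or len(qk_var) != m_rows:
        raise ValueError("q, k and qk_var must all have M rows")
    if any(len(row) != 2 for row in qk_var):
        raise ValueError("qk_var must be [M, 2]")
    if any(len(row) != len(gamma_q) for row in q):
        raise ValueError("gamma_q must have Nq entries")
    if any(len(row) != len(gamma_k) for row in k):
        raise ValueError("gamma_k must have Nk entries")

    eps = to_f32(epsilon)  # epsilon is a float32 host scalar

    def apply_one(x, col, gamma):
        out = []
        for m, row in enumerate(x):
            # Per-row reciprocal RMS from the precomputed statistic in column `col`.
            var = to_f32(qk_var[m][col])
            rs = to_f32(1.0 / math.sqrt(to_f32(var + eps)))
            # Same grouping as the unfused graph: (x * rs) first, then * gamma,
            # rounding to float32 after each multiply. The cast to a narrower
            # activation dtype is NOT modeled here (assumption: float32 output).
            out.append([to_f32(to_f32(to_f32(v) * rs) * to_f32(gamma[c]))
                        for c, v in enumerate(row)])
        return out

    return apply_one(q, 0, gamma_q), apply_one(k, 1, gamma_k)


# epsilon=0.0 only so that the arithmetic below is exact.
q_out, k_out = apply_qk_rmsnorm_ref(
    q=[[1.0, -2.0]], k=[[4.0]], qk_var=[[4.0, 16.0]],
    gamma_q=[1.0, 2.0], gamma_k=[3.0], epsilon=0.0)
print(q_out, k_out)  # [[0.5, -2.0]] [[3.0]]
```

The rounding after each multiply is there because, in floating point, (x * rs) * gamma and x * (rs * gamma) can differ in the last bit. Reproducing the unfused graph's order is what keeps the fused result comparable bit for bit.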

Implemented traits

AnyType, ImplicitlyDeletable

Methods

execute

static def execute[target: StringSlice[StaticConstantOrigin]](q_out: ManagedTensorSlice[Output, static_spec=q_out.static_spec], k_out: ManagedTensorSlice[Output, static_spec=k_out.static_spec], q: ManagedTensorSlice[Input, static_spec=q.static_spec], k: ManagedTensorSlice[Input, static_spec=k.static_spec], qk_var: ManagedTensorSlice[Input, static_spec=qk_var.static_spec], gamma_q: ManagedTensorSlice[Input, static_spec=gamma_q.static_spec], gamma_k: ManagedTensorSlice[Input, static_spec=gamma_k.static_spec], epsilon: Float32, ctx: DeviceContext)
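The page states only that execute produces both outputs in a single launch. One hypothetical way a single grid could cover both operands is a flat work index over M*Nq + M*Nk elements: the first M*Nq indices address q_out and the rest address k_out, each in row-major order. The sketch below shows that mapping in Python. `split_work_index` is an illustrative name, and nothing here describes how execute actually partitions work on the device.

```python
def split_work_index(i, m_rows, nq, nk):
    """Hypothetical mapping from a flat work-item index to (operand, m, c).

    Indices [0, M*Nq) address q_out and [M*Nq, M*Nq + M*Nk) address k_out,
    both row-major. This is an illustrative layout, not the kernel's.
    """
    q_elems = m_rows * nq
    total = q_elems + m_rows * nk
    if not 0 <= i < total:
        raise IndexError(f"work index {i} outside [0, {total})")
    if i < q_elems:
        m, c = divmod(i, nq)
        return "q", m, c
    m, c = divmod(i - q_elems, nk)
    return "k", m, c


# M=2, Nq=3, Nk=2: ten work items in total.
for i in range(10):
    print(i, split_work_index(i, 2, 3, 2))
```

Under this layout, a "q" item would read qk_var[m, 0] and gamma_q[c], and a "k" item would read qk_var[m, 1] and gamma_k[c], before applying the formula given at the top of the page.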