
Mojo function

apply_qk_rms_norm

def apply_qk_rms_norm[in_dtype: DType, out_dtype: DType, //, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](q_out: TileTensor[out_dtype, address_space=q_out.address_space, linear_idx_type=q_out.linear_idx_type, element_size=q_out.element_size], k_out: TileTensor[out_dtype, address_space=k_out.address_space, linear_idx_type=k_out.linear_idx_type, element_size=k_out.element_size], gamma_q: TileTensor[DType.float32, address_space=gamma_q.address_space, linear_idx_type=gamma_q.linear_idx_type, element_size=gamma_q.element_size], gamma_k: TileTensor[DType.float32, address_space=gamma_k.address_space, linear_idx_type=gamma_k.linear_idx_type, element_size=gamma_k.element_size], qk_var: TileTensor[DType.float32, address_space=qk_var.address_space, linear_idx_type=qk_var.linear_idx_type, element_size=qk_var.element_size], q: TileTensor[in_dtype, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[in_dtype, address_space=k.address_space, linear_idx_type=k.linear_idx_type, element_size=k.element_size], epsilon: Float32, rows: Int, q_cols: Int, k_cols: Int, ctx: DeviceContext)

Fused per-element QK-RMSNorm apply for two operands Q and K.

Given the already cross-rank-reduced per-row statistics qk_var [M, 2] (col 0 = mean(q^2), col 1 = mean(k^2), float32) and per-column float32 scales gamma_q [Nq] / gamma_k [Nk], applies in a single launch:

q_out[m, c] = cast((cast(q[m, c], f32) * rsqrt(qk_var[m, 0] + eps)) * gamma_q[c], out_dtype)

and likewise for K, using column 1 of qk_var and gamma_k. Here eps is the epsilon argument. The grouping ((x * rs) * gamma), followed by the cast, matches the unfused graph that this function replaces, which preserves bit-accuracy. The function fuses the QK-RMSNorm apply chain (~7 tiny elementwise/View kernels) into one launch and is used for cross-head QK-RMSNorm under tensor parallelism.
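The formula above can be checked against a small host-side reference. The following is a minimal sketch in pure Python, not the Mojo kernel: the names `f32` and `apply_qk_rms_norm_ref` are hypothetical helpers introduced here for illustration, float32 rounding is emulated with the `struct` module, and the final cast to out_dtype is assumed to be float32 (bfloat16 rounding is not modeled).

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def _apply_one(x, gamma, qk_var, col, eps):
    """Apply ((x * rs) * gamma) per element, rounding to float32 at each step.

    x      : list of M rows, each a list of N floats
    gamma  : N per-column scales
    qk_var : M rows of [mean(q^2), mean(k^2)], assumed already reduced
    col    : 0 selects the Q statistic, 1 selects the K statistic
    """
    out = []
    for m, row in enumerate(x):
        # rs = rsqrt(var + eps), computed once per row
        rs = f32(1.0 / math.sqrt(f32(qk_var[m][col] + eps)))
        out.append(
            [f32(f32(f32(v) * rs) * f32(gamma[c])) for c, v in enumerate(row)]
        )
    return out


def apply_qk_rms_norm_ref(q, k, gamma_q, gamma_k, qk_var, eps):
    """Hypothetical reference: returns (q_out, k_out) as nested lists."""
    q_out = _apply_one(q, gamma_q, qk_var, 0, eps)  # Q uses column 0
    k_out = _apply_one(k, gamma_k, qk_var, 1, eps)  # K uses column 1
    return q_out, k_out
```

The multiplication order mirrors the documented grouping: the input is scaled by the reciprocal root first, and only then by gamma. Because qk_var is an input, the sketch never computes the statistics itself; the caller supplies values that were already reduced across ranks.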

All operands (q / k activations, the outputs q_out / k_out, and the gamma_q / gamma_k / qk_var inputs) are passed directly as TileTensors and loaded/stored in-kernel, matching the in-file rms_norm gamma.load[...] idiom.

Parameters:

  • in_dtype (DType): Element type of both activation inputs (bfloat16 or float32).
  • out_dtype (DType): Element type of the outputs (typically equal to in_dtype).
  • target (StringSlice[StaticConstantOrigin]): "cpu" or a GPU target string.

Args: