Mojo function
apply_qk_rms_norm
def apply_qk_rms_norm[in_dtype: DType, out_dtype: DType, //, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](q_out: TileTensor[out_dtype, Storage=q_out.Storage, address_space=q_out.address_space, linear_idx_type=q_out.linear_idx_type, element_size=q_out.element_size], k_out: TileTensor[out_dtype, Storage=k_out.Storage, address_space=k_out.address_space, linear_idx_type=k_out.linear_idx_type, element_size=k_out.element_size], gamma_q: TileTensor[DType.float32, Storage=gamma_q.Storage, address_space=gamma_q.address_space, linear_idx_type=gamma_q.linear_idx_type, element_size=gamma_q.element_size], gamma_k: TileTensor[DType.float32, Storage=gamma_k.Storage, address_space=gamma_k.address_space, linear_idx_type=gamma_k.linear_idx_type, element_size=gamma_k.element_size], qk_var: TileTensor[DType.float32, Storage=qk_var.Storage, address_space=qk_var.address_space, linear_idx_type=qk_var.linear_idx_type, element_size=qk_var.element_size], q: TileTensor[in_dtype, Storage=q.Storage, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[in_dtype, Storage=k.Storage, address_space=k.address_space, linear_idx_type=k.linear_idx_type, element_size=k.element_size], epsilon: Float32, rows: Int, q_cols: Int, k_cols: Int, ctx: DeviceContext)
Fused per-element QK-RMSNorm apply for two operands, Q and K.
Given per-row statistics qk_var [M, 2] that have already been reduced across ranks
(col 0 = mean(q^2), col 1 = mean(k^2), float32) and per-column float32
scales gamma_q [Nq] / gamma_k [Nk], it applies, in a single launch:
q_out[m,c] = cast((cast(q[m,c], f32) * rsqrt(qk_var[m,0] + eps)) * gamma_q[c], out_dtype)
and likewise for K, using column 1 of qk_var and gamma_k. Evaluating the product as
((x * rs) * gamma), where rs = rsqrt(var + eps), and only then casting matches the
unfused graph this kernel replaces, which preserves bit-accuracy. The function fuses the
QK-RMSNorm apply chain (about 7 small elementwise/View kernels) into one launch and is
used for cross-head QK-RMSNorm under tensor parallelism.
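To make the arithmetic concrete, here is a minimal plain-Python reference model of the apply step. It is a sketch, not the MAX implementation: the names f32 and apply_qk_rms_norm_ref are hypothetical, rows are plain lists, and float32 rounding is emulated with the struct module after each operation. A hardware rsqrt may differ from 1 / sqrt(...) in the last bit, so treat this as an approximate oracle rather than a bit-exact one. The final cast to out_dtype is left to the caller.

```python
import math
import struct


def f32(x: float) -> float:
    """Round a Python float (binary64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def apply_qk_rms_norm_ref(q, k, gamma_q, gamma_k, qk_var, eps):
    """Hypothetical reference model of the fused QK-RMSNorm apply.

    q: M rows of Nq values; k: M rows of Nk values.
    qk_var: M pairs [mean(q^2), mean(k^2)], already reduced across ranks.
    Returns (q_out, k_out) as float32-valued rows, before the out_dtype cast.
    """

    def scale(rows, gamma, col):
        result = []
        for m, row in enumerate(rows):
            # Per-row reciprocal RMS: rs = rsqrt(qk_var[m, col] + eps).
            rs = f32(1.0 / math.sqrt(f32(qk_var[m][col] + eps)))
            # Same grouping as the kernel: (x * rs) first, then * gamma[c].
            result.append([f32(f32(f32(x) * rs) * gamma[c]) for c, x in enumerate(row)])
        return result

    return scale(q, gamma_q, 0), scale(k, gamma_k, 1)


# Usage: one row, eps = 0 so every value is exactly representable.
q_out, k_out = apply_qk_rms_norm_ref(
    q=[[2.0, 2.0]], k=[[4.0]],
    gamma_q=[1.0, 3.0], gamma_k=[2.0],
    qk_var=[[4.0, 16.0]],  # mean(q^2) = 4, mean(k^2) = 16
    eps=0.0,
)
print(q_out, k_out)  # [[1.0, 3.0]] [[2.0]]
```

Keeping Q and K in one call mirrors the single-launch design: both operands share the row index m and differ only in which qk_var column and which gamma vector they use.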
All operands (the q / k activations, the q_out / k_out outputs, and the
gamma_q / gamma_k / qk_var inputs) are passed directly as
TileTensors and are loaded and stored inside the kernel, following the same
gamma.load[...] idiom as rms_norm in the same source file.
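The qk_var input is described as already reduced across ranks, but this page does not show how it is produced. One plausible scheme, assumed here rather than taken from the MAX source, is for each rank to compute per-row sums of squares over its local column shard, sum them across ranks (an all-reduce in a real deployment, a plain Python loop below), and divide by the global column counts. The helper names local_sum_sq and reduce_qk_var are hypothetical, and the arithmetic uses Python floats (binary64), not float32.

```python
def local_sum_sq(rows):
    """Per-row sum of squares over one rank's local columns."""
    return [sum(x * x for x in row) for row in rows]


def reduce_qk_var(q_shards, k_shards, q_cols_total, k_cols_total):
    """Hypothetical sketch: build qk_var [M, 2] from column-sharded Q and K.

    q_shards / k_shards hold one [M, local_cols] block of rows per rank.
    The loop over ranks stands in for a sum all-reduce.
    """
    m = len(q_shards[0])
    q_sum = [0.0] * m
    k_sum = [0.0] * m
    for q_local, k_local in zip(q_shards, k_shards):
        for i, s in enumerate(local_sum_sq(q_local)):
            q_sum[i] += s
        for i, s in enumerate(local_sum_sq(k_local)):
            k_sum[i] += s
    # Column 0 = mean(q^2), column 1 = mean(k^2), over the full (unsharded) width.
    return [[q_sum[i] / q_cols_total, k_sum[i] / k_cols_total] for i in range(m)]


# Usage: a single row of Q = [1, 2, 3, 4] and K = [2, 2], split across two ranks.
qk_var = reduce_qk_var(
    q_shards=[[[1.0, 2.0]], [[3.0, 4.0]]],
    k_shards=[[[2.0]], [[2.0]]],
    q_cols_total=4,
    k_cols_total=2,
)
print(qk_var)  # [[7.5, 4.0]]
```

Reducing sums of squares (rather than per-rank means) keeps the result correct even if shards have unequal widths; with equal shards the two approaches agree.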
Parameters:
- in_dtype (DType): Element type of both activation inputs (bfloat16 or float32).
- out_dtype (DType): Element type of the outputs (typically equal to in_dtype).
- target (StringSlice[StaticConstantOrigin]): "cpu" or a GPU target string.
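When out_dtype is bfloat16, the final cast(..., out_dtype) rounds each float32 result to bfloat16. The sketch below emulates that step in plain Python, assuming IEEE round-to-nearest-even (the usual float32-to-bfloat16 conversion mode, but not something this page states). to_bfloat16 is a hypothetical helper, not part of the Mojo API; it keeps the top 16 bits of the float32 encoding after adding a rounding bias.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round a float32 value to bfloat16 (ties to even); return it as a Python float."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if (bits & 0x7F800000) == 0x7F800000 and (bits & 0x007FFFFF):
        return float("nan")  # NaN stays NaN instead of rounding its payload
    # Add 0x7FFF plus the lowest kept bit so exact halfway cases round to even.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


print(to_bfloat16(1.00390625))  # halfway between 1.0 and 1.0078125, prints 1.0 (even)
print(to_bfloat16(1.01171875))  # halfway, rounds to the even neighbor, prints 1.015625
print(math.isnan(to_bfloat16(float("nan"))))  # True
```

Because bfloat16 keeps only 7 explicit mantissa bits, this single rounding at the end dominates the error budget when out_dtype is bfloat16, which is one reason the float32 grouping before the cast matters for matching the unfused graph.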
Args:
- q_out (TileTensor[out_dtype, Storage=q_out.Storage, address_space=q_out.address_space, linear_idx_type=q_out.linear_idx_type, element_size=q_out.element_size]): Scaled Q output, shape [M, Nq].
- k_out (TileTensor[out_dtype, Storage=k_out.Storage, address_space=k_out.address_space, linear_idx_type=k_out.linear_idx_type, element_size=k_out.element_size]): Scaled K output, shape [M, Nk].
- gamma_q (TileTensor[DType.float32, Storage=gamma_q.Storage, address_space=gamma_q.address_space, linear_idx_type=gamma_q.linear_idx_type, element_size=gamma_q.element_size]): Per-column float32 Q scales, shape [Nq].
- gamma_k (TileTensor[DType.float32, Storage=gamma_k.Storage, address_space=gamma_k.address_space, linear_idx_type=gamma_k.linear_idx_type, element_size=gamma_k.element_size]): Per-column float32 K scales, shape [Nk].
- qk_var (TileTensor[DType.float32, Storage=qk_var.Storage, address_space=qk_var.address_space, linear_idx_type=qk_var.linear_idx_type, element_size=qk_var.element_size]): Per-row float32 statistics, shape [M, 2] (col 0 = mean(q^2), col 1 = mean(k^2)).
- q (TileTensor[in_dtype, Storage=q.Storage, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size]): Q activations, shape [M, Nq].
- k (TileTensor[in_dtype, Storage=k.Storage, address_space=k.address_space, linear_idx_type=k.linear_idx_type, element_size=k.element_size]): K activations, shape [M, Nk].
- epsilon (Float32): RMSNorm epsilon, added to the variance before rsqrt.
- rows (Int): Shared leading dimension of Q and K (M).
- q_cols (Int): Number of columns of Q (Nq).
- k_cols (Int): Number of columns of K (Nk).
- ctx (DeviceContext): Device context (ignored on CPU).
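The shape contract in the argument list can be summarized as a small check. This is a hypothetical host-side helper written in plain Python over shape tuples; the Mojo function takes TileTensors, and this page does not say whether it validates shapes itself.

```python
def check_qk_rms_norm_shapes(shapes, rows, q_cols, k_cols):
    """Hypothetical check of operand shapes against the documented contract.

    shapes: dict mapping each argument name to a shape tuple.
    Raises ValueError on the first mismatch.
    """
    expected = {
        "q": (rows, q_cols),
        "q_out": (rows, q_cols),
        "k": (rows, k_cols),
        "k_out": (rows, k_cols),
        "gamma_q": (q_cols,),
        "gamma_k": (k_cols,),
        "qk_var": (rows, 2),  # col 0 = mean(q^2), col 1 = mean(k^2)
    }
    for name, want in expected.items():
        got = tuple(shapes[name])
        if got != want:
            raise ValueError(f"{name}: expected shape {want}, got {got}")


# Usage: M = 8 rows, Nq = 128, Nk = 64.
good = {
    "q": (8, 128), "q_out": (8, 128),
    "k": (8, 64), "k_out": (8, 64),
    "gamma_q": (128,), "gamma_k": (64,),
    "qk_var": (8, 2),
}
check_qk_rms_norm_shapes(good, rows=8, q_cols=128, k_cols=64)  # passes silently
```

Q and K may have different widths (Nq and Nk), but they must share the row count M, because each row m reads both columns of the same qk_var row.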