IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo function

apply_qk_rms_norm_gpu_block

def apply_qk_rms_norm_gpu_block[in_dtype: DType, out_dtype: DType, q_out_mut: Bool, q_out_layout: TensorLayout, q_out_origin: Origin[mut=q_out_mut], k_out_mut: Bool, k_out_layout: TensorLayout, k_out_origin: Origin[mut=k_out_mut], gamma_q_mut: Bool, gamma_q_layout: TensorLayout, gamma_q_origin: Origin[mut=gamma_q_mut], gamma_k_mut: Bool, gamma_k_layout: TensorLayout, gamma_k_origin: Origin[mut=gamma_k_mut], var_mut: Bool, var_layout: TensorLayout, var_origin: Origin[mut=var_mut], q_layout: TensorLayout, q_origin: Origin[mut=q_origin.mut], k_layout: TensorLayout, k_origin: Origin[mut=k_origin.mut], //, simd_width: Int](q_out: TileTensor[out_dtype, q_out_layout, q_out_origin], k_out: TileTensor[out_dtype, k_out_layout, k_out_origin], gamma_q: TileTensor[DType.float32, gamma_q_layout, gamma_q_origin], gamma_k: TileTensor[DType.float32, gamma_k_layout, gamma_k_origin], qk_var: TileTensor[DType.float32, var_layout, var_origin], q: TileTensor[in_dtype, q_layout, q_origin], k: TileTensor[in_dtype, k_layout, k_origin], epsilon: Float32, q_cols: Int, k_cols: Int) where q_out_mut and k_out_mut

Fused, per-element application of QK-RMSNorm to both Q and K in a single kernel launch.
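The Python sketch below is a minimal, unfused CPU reference for what the apply step computes. It rests on two assumptions that this page does not state: qk_var[m][0] and qk_var[m][1] hold the mean of squares of row m of Q and K (produced by an earlier reduction pass), and the per-row scale is rs = 1 / sqrt(var + epsilon). The helper names qk_rms_norm_apply_ref and row_mean_square are hypothetical and are not part of the MAX API.

```python
import math


def row_mean_square(row):
    """Mean of squares of one row (assumed contents of a qk_var entry)."""
    return sum(v * v for v in row) / len(row)


def qk_rms_norm_apply_ref(q, k, gamma_q, gamma_k, qk_var, epsilon):
    """Hypothetical unfused reference: out[m][j] = (x[m][j] * rs[m]) * gamma[j]."""

    def apply(x, gamma, operand):
        out = []
        for m, row in enumerate(x):
            # Assumption: rs is the reciprocal RMS derived from qk_var and epsilon.
            rs = 1.0 / math.sqrt(qk_var[m][operand] + epsilon)
            out.append([(v * rs) * g for v, g in zip(row, gamma)])
        return out

    # Column 0 of qk_var is assumed to belong to Q, column 1 to K.
    return apply(q, gamma_q, 0), apply(k, gamma_k, 1)


q = [[2.0, -2.0]]
k = [[3.0, 3.0, 3.0]]
gamma_q = [1.0, 0.5]
gamma_k = [2.0, 2.0, 2.0]
qk_var = [[row_mean_square(qr), row_mean_square(kr)] for qr, kr in zip(q, k)]
q_out, k_out = qk_rms_norm_apply_ref(q, k, gamma_q, gamma_k, qk_var, epsilon=0.0)
print(q_out)  # [[1.0, -0.5]]
print(k_out)  # [[2.0, 2.0, 2.0]]
```

Q and K keep separate gammas and separate variance columns even though they share rows, which is what lets one launch normalize both operands.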

The grid is 2D: block_idx.x selects the row and block_idx.y selects the operand (0 = Q, 1 = K). Each block handles one (row, operand) pair, and its threads stride across that operand's columns, computing ((x * rs) * gamma) for each element. All operands (q [M, Nq], k [M, Nk], gamma_q [Nq], gamma_k [Nk], qk_var [M, 2], and the outputs q_out [M, Nq] and k_out [M, Nk]) are loaded and stored directly through their TileTensors, matching the gamma.load[...] idiom used by rms_norm in the same file.
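To illustrate the launch layout, the following sketch emulates the (M, 2) grid on the CPU: the two outer loops stand in for block_idx.x and block_idx.y, and each emulated thread tid visits columns tid, tid + block_dim, and so on. This is a model under stated assumptions, not the kernel itself: emulate_qk_apply_grid and block_dim are hypothetical, qk_var is assumed to be indexed as [row][operand] with rs = 1 / sqrt(var + epsilon), and the real kernel's simd_width vectorized loads are not modeled (each step here handles one scalar).

```python
import math


def emulate_qk_apply_grid(q, k, gamma_q, gamma_k, qk_var, epsilon, block_dim):
    """Hypothetical CPU emulation of a (rows, 2) grid with block_dim threads per block."""
    rows = len(q)
    operands = [  # indexed by block_idx.y: 0 = Q, 1 = K
        (q, gamma_q, [[None] * len(q[0]) for _ in range(rows)]),
        (k, gamma_k, [[None] * len(k[0]) for _ in range(rows)]),
    ]
    for block_x in range(rows):  # block_idx.x selects the row
        for block_y in range(2):  # block_idx.y selects the operand
            x, gamma, out = operands[block_y]
            cols = len(x[block_x])
            # Assumption: one reciprocal RMS per (row, operand) from qk_var.
            rs = 1.0 / math.sqrt(qk_var[block_x][block_y] + epsilon)
            for tid in range(block_dim):  # every thread in this block
                # Stride loop: thread tid handles tid, tid + block_dim, ...
                for col in range(tid, cols, block_dim):
                    assert out[block_x][col] is None, "element written twice"
                    out[block_x][col] = (x[block_x][col] * rs) * gamma[col]
    return operands[0][2], operands[1][2]


q = [[2.0, -2.0, 2.0, -2.0, 2.0], [1.0, 1.0, 1.0, 1.0, 1.0]]
k = [[3.0, 3.0, 3.0], [1.0, 1.0, 1.0]]
qk_var = [[4.0, 9.0], [1.0, 1.0]]  # assumed [row][operand] mean of squares
q_out, k_out = emulate_qk_apply_grid(
    q, k, [1.0] * 5, [2.0] * 3, qk_var, epsilon=0.0, block_dim=2
)
```

The assertion inside the innermost loop checks that every output element is written by exactly one thread, even when block_dim is smaller than the operand's column count, so Q and K can have different widths (Nq and Nk) within the same launch.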