Mojo function
apply_qk_rms_norm_gpu_block
def apply_qk_rms_norm_gpu_block[in_dtype: DType, out_dtype: DType, q_out_mut: Bool, q_out_layout: TensorLayout, q_out_origin: Origin[mut=q_out_mut], k_out_mut: Bool, k_out_layout: TensorLayout, k_out_origin: Origin[mut=k_out_mut], gamma_q_mut: Bool, gamma_q_layout: TensorLayout, gamma_q_origin: Origin[mut=gamma_q_mut], gamma_k_mut: Bool, gamma_k_layout: TensorLayout, gamma_k_origin: Origin[mut=gamma_k_mut], var_mut: Bool, var_layout: TensorLayout, var_origin: Origin[mut=var_mut], q_layout: TensorLayout, q_origin: Origin[mut=q_origin.mut], k_layout: TensorLayout, k_origin: Origin[mut=k_origin.mut], //, simd_width: Int](q_out: TileTensor[out_dtype, q_out_layout, q_out_origin], k_out: TileTensor[out_dtype, k_out_layout, k_out_origin], gamma_q: TileTensor[DType.float32, gamma_q_layout, gamma_q_origin], gamma_k: TileTensor[DType.float32, gamma_k_layout, gamma_k_origin], qk_var: TileTensor[DType.float32, var_layout, var_origin], q: TileTensor[in_dtype, q_layout, q_origin], k: TileTensor[in_dtype, k_layout, k_origin], epsilon: Float32, q_cols: Int, k_cols: Int) where q_out_mut and k_out_mut
Fused per-element QK-RMSNorm application: applies the normalization to both Q and K in a single kernel launch.
The grid is 2D: block_idx.x selects the row and block_idx.y selects the
operand (0 = Q, 1 = K). Each block owns one (row, operand) pair, and its threads
stride across that operand's columns, computing (x * rs) * gamma for each element.
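The block-to-work mapping described above can be modeled on the host with plain Python loops. This is a minimal sketch, not the kernel itself. It assumes that rs = 1 / sqrt(qk_var[row][operand] + epsilon), uses a hypothetical block size `block_dim_x`, processes one element at a time (ignoring simd_width), and uses nested lists in place of TileTensors.

```python
import math


def simulate_qk_apply(q, k, gamma_q, gamma_k, qk_var, epsilon, block_dim_x=4):
    """Host-side model of the fused Q/K apply; returns (q_out, k_out)."""
    m = len(q)
    q_cols, k_cols = len(gamma_q), len(gamma_k)
    q_out = [[0.0] * q_cols for _ in range(m)]
    k_out = [[0.0] * k_cols for _ in range(m)]

    # Grid is (M, 2): block_x picks the row, block_y picks the operand.
    for block_x in range(m):
        for block_y in range(2):
            if block_y == 0:
                src, gamma, dst, cols = q, gamma_q, q_out, q_cols
            else:
                src, gamma, dst, cols = k, gamma_k, k_out, k_cols

            # ASSUMPTION: qk_var holds the per-row mean of squares, so the
            # row scale rs is the reciprocal RMS.
            rs = 1.0 / math.sqrt(qk_var[block_x][block_y] + epsilon)

            # Each thread in the block strides over the columns by block_dim_x.
            # block_dim_x is a hypothetical block size, not the real launch config.
            for thread_x in range(block_dim_x):
                for col in range(thread_x, cols, block_dim_x):
                    x = src[block_x][col]
                    dst[block_x][col] = (x * rs) * gamma[col]
    return q_out, k_out
```

Each block touches only its own (row, operand) slice, and each column is visited by exactly one thread. As a result, the output of this model does not depend on `block_dim_x`.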
All inputs (q [M, Nq], k [M, Nk], gamma_q [Nq], gamma_k [Nk], and
qk_var [M, 2]) are loaded, and the outputs (q_out [M, Nq] and k_out [M, Nk])
are stored, directly through their TileTensors, matching the gamma.load[...]
idiom used by rms_norm in the same source file.
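The documented shapes can be expressed as a small host-side check. This sketch is illustrative only. `check_qk_shapes` and `mean_square_var` are hypothetical helpers. The layout of qk_var (column 0 holds Q's per-row mean of squares, column 1 holds K's) is an assumption inferred from the [M, 2] shape and the operand convention (0 = Q, 1 = K); this page does not confirm it.

```python
def check_qk_shapes(q, k, gamma_q, gamma_k, qk_var):
    """Validate the documented shapes and return (M, Nq, Nk)."""
    m = len(q)
    if len(k) != m or len(qk_var) != m:
        raise ValueError("q, k and qk_var must all have M rows")
    nq, nk = len(gamma_q), len(gamma_k)
    if any(len(row) != nq for row in q):
        raise ValueError("every q row must have Nq = len(gamma_q) columns")
    if any(len(row) != nk for row in k):
        raise ValueError("every k row must have Nk = len(gamma_k) columns")
    if any(len(row) != 2 for row in qk_var):
        raise ValueError("qk_var must have shape [M, 2]")
    return m, nq, nk


def mean_square_var(q, k):
    """Hypothetical producer of qk_var.

    ASSUMPTION: column 0 is the mean of squares of the Q row and
    column 1 is the mean of squares of the K row.
    """
    return [
        [sum(x * x for x in q_row) / len(q_row), sum(x * x for x in k_row) / len(k_row)]
        for q_row, k_row in zip(q, k)
    ]
```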