Mojo function

row_mean_of_squares_qk

def row_mean_of_squares_qk[in_dtype: DType, out_dtype: DType, //, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](output: TileTensor[out_dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[in_dtype, Storage=q.Storage, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[in_dtype, Storage=k.Storage, address_space=k.address_space, linear_idx_type=k.linear_idx_type, element_size=k.element_size], rows: Int, q_cols: Int, k_cols: Int, ctx: DeviceContext)

Fused per-row mean of squares for two operands Q and K.

Computes out[m, 0] = sum_n(q[m, n]^2) / q_cols and out[m, 1] = sum_n(k[m, n]^2) / k_cols, with both sums accumulated in accum_type. Q and K share the leading rows dimension (M) but may have different column counts (Nq and Nk). The function fuses two row_mean_of_squares reductions into a single launch and is used for cross-head QK-RMSNorm statistics under tensor parallelism.

All operands (q [M, Nq], k [M, Nk], and the [M, 2] output) are passed directly as TileTensors and loaded/stored in-kernel.
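The arithmetic described above can be sketched as a plain reference in Python. This is only an illustration of the documented formula, not the Mojo kernel or its API: the function name `row_mean_of_squares_qk_reference` and the nested-list layout are assumptions made for the example, and Python floats stand in for the accumulation type.

```python
def row_mean_of_squares_qk_reference(q, k):
    """Reference semantics only (hypothetical helper, not the Mojo API).

    q: list of M rows, each with Nq values.
    k: list of M rows, each with Nk values.
    Returns a list of M rows [mean(q_row^2), mean(k_row^2)].
    """
    if len(q) != len(k):
        raise ValueError("q and k must share the leading rows dimension")

    out = []
    for q_row, k_row in zip(q, k):
        # out[m, 0] = sum_n(q[m, n]^2) / q_cols
        q_mean_sq = sum(x * x for x in q_row) / len(q_row)
        # out[m, 1] = sum_n(k[m, n]^2) / k_cols
        k_mean_sq = sum(x * x for x in k_row) / len(k_row)
        out.append([q_mean_sq, k_mean_sq])
    return out
```

For q = [[1.0, 2.0, 3.0, 4.0]] and k = [[3.0, 4.0]], the sketch returns [[7.5, 12.5]]: the Q row sums to 30 over 4 columns, and the K row sums to 25 over 2 columns. A reference like this can be handy for checking kernel output on small inputs, within the tolerance expected for the chosen in_dtype.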

Parameters:

  • in_dtype (DType): Element type of both inputs (e.g. bfloat16 or float32).
  • out_dtype (DType): Element type of the per-row result (typically float32).
  • target (StringSlice[StaticConstantOrigin]): "cpu" or a GPU target string.

Args: