Mojo function
row_mean_of_squares_qk
def row_mean_of_squares_qk[in_dtype: DType, out_dtype: DType, //, target: StringSlice[StaticConstantOrigin] = StringSlice("cpu")](output: TileTensor[out_dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], q: TileTensor[in_dtype, Storage=q.Storage, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size], k: TileTensor[in_dtype, Storage=k.Storage, address_space=k.address_space, linear_idx_type=k.linear_idx_type, element_size=k.element_size], rows: Int, q_cols: Int, k_cols: Int, ctx: DeviceContext)
Fused per-row mean of squares for two operands Q and K.
Computes out[m, 0] = sum_n(q[m,n]^2) / q_cols and
out[m, 1] = sum_n(k[m,n]^2) / k_cols, accumulated in accum_type. Q and K
share the leading rows dimension but may have different column counts. This
is a single-launch fusion of two row_mean_of_squares reductions, used for
cross-head QK-RMSNorm statistics under tensor parallelism.
All operands (q [M, Nq], k [M, Nk], and the [M, 2] output) are passed
directly as TileTensors and loaded/stored in-kernel.
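The two formulas above can be checked against a small reference. The following is a minimal sketch in plain Python, not the Mojo kernel: the name row_mean_of_squares_qk_ref is hypothetical, nested lists stand in for the TileTensor operands, and Python floats stand in for the accumulation type. Row and column counts are taken from the list lengths instead of the rows, q_cols, and k_cols arguments.

```python
def row_mean_of_squares_qk_ref(q, k):
    """Hypothetical reference: returns an [M, 2] nested list.

    Column 0 is mean(q[m, :]^2); column 1 is mean(k[m, :]^2).
    """
    if len(q) != len(k):
        raise ValueError("q and k must share the leading rows dimension")
    out = []
    for q_row, k_row in zip(q, k):
        # Accumulate in Python float (double precision) as a stand-in
        # for the kernel's accumulation type.
        q_mean = sum(float(x) * float(x) for x in q_row) / len(q_row)
        k_mean = sum(float(x) * float(x) for x in k_row) / len(k_row)
        out.append([q_mean, k_mean])
    return out


# Q is [2, 4] and K is [2, 2]: same row count, different column counts.
q = [[1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 2.0, 2.0]]
k = [[3.0, 4.0], [1.0, 1.0]]
result = row_mean_of_squares_qk_ref(q, k)
print(result)
```

For the first row, Q gives (1 + 4 + 9 + 16) / 4 = 7.5 and K gives (9 + 16) / 2 = 12.5, which shows each operand being divided by its own column count.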
Parameters:
- in_dtype (DType): Element type of both inputs (e.g. bfloat16 or float32).
- out_dtype (DType): Element type of the per-row result (typically float32).
- target (StringSlice[StaticConstantOrigin]): "cpu" or a GPU target string.
Args:
- output (TileTensor[out_dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Per-row result, shape [M, 2] (col 0 = mean(q^2), col 1 = mean(k^2)).
- q (TileTensor[in_dtype, Storage=q.Storage, address_space=q.address_space, linear_idx_type=q.linear_idx_type, element_size=q.element_size]): Q activations, shape [M, Nq].
- k (TileTensor[in_dtype, Storage=k.Storage, address_space=k.address_space, linear_idx_type=k.linear_idx_type, element_size=k.element_size]): K activations, shape [M, Nk].
- rows (Int): Shared leading dimension of Q and K.
- q_cols (Int): Number of columns reduced for Q.
- k_cols (Int): Number of columns reduced for K.
- ctx (DeviceContext): Device context (ignored on CPU).