
Mojo module

gemv_partial_norm

Fused matvec (M=1) + partial RMS norm on B200.

Given x [1, K], W [N, K] (with transpose_b=True), gamma [N_normed], and eps, this module computes:

y        = x @ W.T                             # [1, N]
normed   = rms_norm(y[:, :N_normed], gamma, eps)
unnormed = y[:, N_normed:]
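As a point of reference, the computation above can be sketched on the host in plain Python. This is only an illustrative reference, not the kernel: the helper names (`gemv`, `rms_norm`, `gemv_partial_norm_reference`) are hypothetical, and `rms_norm` is assumed to follow the common definition `v / sqrt(mean(v^2) + eps) * gamma`, which the module may refine (for example with a different accumulation dtype).

```python
import math

def gemv(x, W):
    """y = x @ W.T for x of length K and W given as N rows of length K."""
    return [sum(xk * wk for xk, wk in zip(x, row)) for row in W]

def rms_norm(v, gamma, eps):
    """Assumed definition: v / sqrt(mean(v^2) + eps) * gamma."""
    mean_sq = sum(e * e for e in v) / len(v)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [e * inv_rms * g for e, g in zip(v, gamma)]

def gemv_partial_norm_reference(x, W, gamma, eps):
    """Return (normed, unnormed); N_normed is taken from len(gamma)."""
    n_normed = len(gamma)
    y = gemv(x, W)                                # [N]
    normed = rms_norm(y[:n_normed], gamma, eps)   # first N_normed columns
    unnormed = y[n_normed:]                       # remaining columns, untouched
    return normed, unnormed

x = [1.0, 2.0]                                    # K = 2
W = [[3.0, 0.0], [0.0, 2.0], [1.0, 1.0]]          # N = 3, so y = [3.0, 4.0, 3.0]
gamma = [1.0, 1.0]                                # N_normed = 2
normed, unnormed = gemv_partial_norm_reference(x, W, gamma, eps=0.0)
```

With `gamma` all ones and `eps = 0`, the normed head has a mean square of 1, which is a convenient sanity check when comparing against the fused or unfused paths.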

The fused kernel does all of this in a single launch. Each normed block performs one device-scope acq_rel fetch_add on finish_counter. The release half orders that block's prior normed_output writes before the counter increment. The acquire half, in the global-last arriver (the block that observes prev == num_normed_blocks - 1), makes every peer's writes visible, so that block can read normed_output back, perform a single-pass intra-block RMS reduction, and apply gamma.
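The last-arriver pattern described above can be illustrated with a small host-side simulation. This is a sketch under loud assumptions: Python threads stand in for GPU blocks, a lock-protected `Counter` stands in for the device-scope acq_rel atomic (the lock supplies the ordering that release/acquire supplies on the device), and the names `Counter`, `run_normed_blocks`, and `block_size` are hypothetical. It also assumes the last arriver normalizes `normed_output` in place.

```python
import math
import threading

class Counter:
    """Lock-based stand-in for a device-scope atomic counter."""
    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def fetch_add(self, delta):
        # Returns the value held before the add, like an atomic fetch_add.
        with self._lock:
            prev = self._value
            self._value += delta
            return prev

def run_normed_blocks(y_head, gamma, eps, block_size):
    n_normed = len(y_head)
    num_normed_blocks = (n_normed + block_size - 1) // block_size
    normed_output = [0.0] * n_normed
    finish_counter = Counter()
    last_arrivers = []

    def block(block_id):
        lo = block_id * block_size
        hi = min(lo + block_size, n_normed)
        # 1. Write this block's slice of the un-normalized matvec result.
        normed_output[lo:hi] = y_head[lo:hi]
        # 2. Signal completion; the writes above happen before the increment.
        prev = finish_counter.fetch_add(1)
        # 3. Only the global-last arriver sees every peer's writes.
        if prev == num_normed_blocks - 1:
            last_arrivers.append(block_id)
            mean_sq = sum(v * v for v in normed_output) / n_normed
            inv_rms = 1.0 / math.sqrt(mean_sq + eps)
            for i in range(n_normed):
                normed_output[i] *= inv_rms * gamma[i]

    threads = [threading.Thread(target=block, args=(b,))
               for b in range(num_normed_blocks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return normed_output, last_arrivers

out, last = run_normed_blocks([3.0, 4.0, 0.0, 0.0], [1.0] * 4, 0.0, block_size=1)
```

Whichever block happens to finish last performs the reduction, so exactly one block normalizes regardless of scheduling order, and no second launch or grid-wide barrier is needed.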

The unfused path is a 2-launch baseline using existing primitives (matmul followed by rms_norm_gpu). The matmul writes the full [M, N] output to a caller-provided y_scratch buffer; the RMS norm then reads y_scratch[:, :N_normed] and writes the normed values to normed_output. The unnormed tail lives as a view into y_scratch[:, N_normed:], matching how model code naturally expresses this.
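The buffer roles in the unfused path can be sketched as follows, assuming M = 1. The function name `unfused_gemv_partial_norm` is hypothetical, the two loops merely stand in for the matmul and rms_norm_gpu launches, and a `memoryview` slice is used to mimic a zero-copy view into `y_scratch`.

```python
import math
from array import array

def unfused_gemv_partial_norm(x, W, gamma, eps, y_scratch, normed_output):
    n_normed = len(gamma)
    # Stand-in for launch 1 (matmul): write the full output row to y_scratch.
    for n, row in enumerate(W):
        y_scratch[n] = sum(xk * wk for xk, wk in zip(x, row))
    # Stand-in for launch 2 (RMS norm): read the head, write normed_output.
    head = memoryview(y_scratch)[:n_normed]
    mean_sq = sum(v * v for v in head) / n_normed
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    for i in range(n_normed):
        normed_output[i] = head[i] * inv_rms * gamma[i]
    # The unnormed tail is a view into y_scratch, not a copy.
    return memoryview(y_scratch)[n_normed:]

x = [1.0, 2.0]
W = [[3.0, 0.0], [0.0, 2.0], [1.0, 1.0]]   # N = 3, K = 2
gamma = [1.0, 1.0]                         # N_normed = 2
y_scratch = array("d", [0.0] * 3)          # caller-provided scratch
normed_output = array("d", [0.0] * 2)
tail = unfused_gemv_partial_norm(x, W, gamma, 0.0, y_scratch, normed_output)
```

Because the tail aliases `y_scratch`, the caller must keep the scratch buffer alive and unmodified for as long as the unnormed values are in use.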

comptime values

GEMV_TRACE_EVENTS_PER_BLOCK

comptime GEMV_TRACE_EVENTS_PER_BLOCK = 16

Number of UInt64 timestamp slots reserved per block in a trace buffer. Only 10 slots are used today (roles 0 through 9); slots 10 through 15 are reserved for future per-iteration instrumentation.
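To make the slot arithmetic concrete, here is a small sketch of how a trace buffer could be sized and indexed. The block-major layout (all of a block's slots stored contiguously) and the helper names `trace_slot` and `trace_buffer_len` are assumptions for illustration; the page only states the per-block slot count and which roles are in use.

```python
GEMV_TRACE_EVENTS_PER_BLOCK = 16   # documented value
NUM_ROLES_IN_USE = 10              # roles 0 through 9, per the description
BYTES_PER_SLOT = 8                 # each slot is a UInt64 timestamp

def trace_buffer_len(num_blocks):
    """Number of UInt64 slots needed for num_blocks blocks."""
    return num_blocks * GEMV_TRACE_EVENTS_PER_BLOCK

def trace_slot(block_id, role):
    """Flat slot index for (block_id, role), assuming a block-major layout."""
    if not 0 <= role < GEMV_TRACE_EVENTS_PER_BLOCK:
        raise ValueError("role out of range for the per-block slot count")
    return block_id * GEMV_TRACE_EVENTS_PER_BLOCK + role

slots = trace_buffer_len(4)                     # 4 blocks
size_bytes = slots * BYTES_PER_SLOT
last_used = trace_slot(2, NUM_ROLES_IN_USE - 1) # block 2, role 9
```

Under this layout, slots 10 through 15 of each block are left unwritten, so a reader of the buffer would need to skip them.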

Functions