Mojo module
gemv_partial_norm
Fused matvec (M=1) + partial RMS norm on B200.
Given x [1, K], W [N, K] (with transpose_b=True), gamma [N_normed], and eps, this module computes:
y = x @ W.T # [1, N]
normed = rms_norm(y[:, :N_normed], gamma, eps)
unnormed = y[:, N_normed:]

The fused kernel does all of this in a single launch. Every normed block performs a single device-scope acq_rel fetch_add on finish_counter. The release half orders the block's prior writes to normed_output before its counter increment. The acquire half, in the globally last block to arrive (the one that observes prev == num_normed_blocks - 1), makes every peer's writes visible. That block can then read normed_output back, perform a single-pass intra-block RMS reduction, and apply gamma.
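The last-arriver pattern can be illustrated with a minimal Python sketch, which makes several assumptions. Python threads stand in for GPU blocks. A threading.Lock stands in for the device-scope acq_rel fetch_add; it gives at least the same happens-before ordering but is far heavier than one atomic. W is stored [N, K] as with transpose_b=True, and RMS norm uses the standard definition y * gamma / sqrt(mean(y^2) + eps). The names fused_gemv_partial_norm and block_n are hypothetical and not part of this module.

```python
import math
import threading


def fused_gemv_partial_norm(x, W, gamma, eps, n_normed, block_n):
    """Hypothetical single-"launch" sketch: threads stand in for GPU blocks."""
    n = len(W)
    num_normed_blocks = (n_normed + block_n - 1) // block_n
    normed_output = [0.0] * n_normed          # holds raw y first, then normed y
    unnormed_output = [0.0] * (n - n_normed)
    finish_counter = [0]
    counter_lock = threading.Lock()           # stand-in for acq_rel fetch_add

    def dot(row):
        return sum(a * b for a, b in zip(x, row))

    def normed_block(b):
        for j in range(b * block_n, min((b + 1) * block_n, n_normed)):
            normed_output[j] = dot(W[j])      # write raw y before the increment
        with counter_lock:                    # publish our writes, see peers' writes
            prev = finish_counter[0]
            finish_counter[0] += 1
        if prev == num_normed_blocks - 1:     # globally last arriver
            mean_sq = sum(v * v for v in normed_output) / n_normed
            inv_rms = 1.0 / math.sqrt(mean_sq + eps)
            for j in range(n_normed):
                normed_output[j] *= inv_rms * gamma[j]

    def unnormed_block(start):
        # Tail blocks never touch the counter: write y and exit.
        for j in range(start, min(start + block_n, n)):
            unnormed_output[j - n_normed] = dot(W[j])

    blocks = [threading.Thread(target=normed_block, args=(b,))
              for b in range(num_normed_blocks)]
    blocks += [threading.Thread(target=unnormed_block, args=(s,))
               for s in range(n_normed, n, block_n)]
    for t in blocks:
        t.start()
    for t in blocks:
        t.join()
    return normed_output, unnormed_output


# y = [1.0, 2.0, 3.0, 2.0]; the first two entries are normalized.
normed, unnormed = fused_gemv_partial_norm(
    x=[1.0, 2.0], W=[[1, 0], [0, 1], [1, 1], [2, 0]],
    gamma=[1.0, 1.0], eps=0.0, n_normed=2, block_n=1)
print(unnormed)  # [3.0, 2.0]
```

Whichever normed block happens to increment last does the reduction, so the result does not depend on block scheduling order. Only normed blocks synchronize on the counter.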
The unfused path is a 2-launch baseline using existing primitives
(matmul followed by rms_norm_gpu). The matmul writes the full
[M, N] output to a caller-provided y_scratch buffer; the RMS
norm then reads y_scratch[:, :N_normed] and writes the normed
values to normed_output. The unnormed tail lives as a view into
y_scratch[:, N_normed:], matching how model code naturally
expresses this.
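The two-launch baseline can be sketched in Python under the same assumptions: W is stored [N, K], and RMS norm is the standard y * gamma / sqrt(mean(y^2) + eps). It is not the actual matmul or rms_norm_gpu API; matmul_transpose_b, rms_norm, and unfused_gemv_partial_norm are hypothetical stand-ins for the two launches. A memoryview over an array plays the role of y_scratch, so the unnormed tail is a zero-copy view rather than a separate buffer.

```python
import math
from array import array


def matmul_transpose_b(x, W):
    """Launch 1 stand-in: y = x @ W.T for M = 1, with W stored as [N, K]."""
    return array("d", (sum(a * b for a, b in zip(x, row)) for row in W))


def rms_norm(values, gamma, eps):
    """Launch 2 stand-in: assumes standard RMSNorm, y * gamma / sqrt(mean(y^2) + eps)."""
    mean_sq = sum(v * v for v in values) / len(values)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * g for v, g in zip(values, gamma)]


def unfused_gemv_partial_norm(x, W, gamma, eps, n_normed):
    y_scratch = matmul_transpose_b(x, W)      # full [1, N] row in "scratch"
    y_view = memoryview(y_scratch)
    normed_output = rms_norm(y_view[:n_normed], gamma, eps)
    unnormed = y_view[n_normed:]              # view, like y_scratch[:, N_normed:]
    return normed_output, unnormed, y_scratch


normed, unnormed, y_scratch = unfused_gemv_partial_norm(
    x=[1.0, 2.0], W=[[1, 0], [0, 1], [1, 1], [2, 0]],
    gamma=[1.0, 1.0], eps=0.0, n_normed=2)
print(unnormed.tolist())  # [3.0, 2.0]
```

Because the tail is a view, the full y row makes a round trip through memory between the two launches. The fused path avoids that extra round trip.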
comptime values
GEMV_TRACE_EVENTS_PER_BLOCK
comptime GEMV_TRACE_EVENTS_PER_BLOCK = 16
Number of UInt64 timestamp slots reserved per block in a trace buffer. Only 10 slots are used today (roles 0 through 9); slots 10 through 15 are reserved for future per-iteration instrumentation.
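The per-block slot budget can be illustrated with a short sketch of how a host might size and index such a buffer. The docs do not state the layout. The sketch assumes a block-major layout in which block b owns 16 consecutive UInt64 slots, and trace_slot and trace_buffer_bytes are hypothetical helpers.

```python
GEMV_TRACE_EVENTS_PER_BLOCK = 16  # value from the module; roles 0..9 used, 10..15 reserved
BYTES_PER_SLOT = 8                # one UInt64 timestamp per slot


def trace_slot(block_idx: int, role: int) -> int:
    """Hypothetical block-major layout: block b owns slots [16*b, 16*b + 16)."""
    if not 0 <= role < GEMV_TRACE_EVENTS_PER_BLOCK:
        raise ValueError(f"role {role} outside 0..{GEMV_TRACE_EVENTS_PER_BLOCK - 1}")
    return block_idx * GEMV_TRACE_EVENTS_PER_BLOCK + role


def trace_buffer_bytes(num_blocks: int) -> int:
    """Bytes needed so every block has its full 16-slot region, reserved slots included."""
    return num_blocks * GEMV_TRACE_EVENTS_PER_BLOCK * BYTES_PER_SLOT


print(trace_slot(3, 9))         # 57
print(trace_buffer_bytes(100))  # 12800
```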
Structs
- GmemTrace: HBM-backed trace buffer. store(offset, ts) writes ts to ptr[offset]. Occupies 8 bytes of kernel arguments.
- NullTrace: Zero-sized, no-op trace buffer. store is pass; the struct has no fields, so it contributes 0 bytes of kernel arguments when passed as an argument.
Traits
- TraceBuf: Trace-buffer interface. Implementations: NullTrace, GmemTrace.
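The trait/struct split can be illustrated with a Python sketch that mirrors the interface shape. TraceBuf is a typing.Protocol. NullTrace has no fields and a no-op store, and GmemTrace wraps a backing list. traced_work, the clock callable, and the meanings given to roles 0 and 1 are hypothetical. In Mojo the kernel is presumably parameterized on the trace type at compile time, which is what lets NullTrace cost nothing. Python gives no such guarantee, so this shows only the shape of the interface, not the zero-cost property.

```python
from typing import Callable, List, Protocol


class TraceBuf(Protocol):
    """Interface: record timestamp ts at slot offset."""
    def store(self, offset: int, ts: int) -> None: ...


class NullTrace:
    """No fields; store does nothing, mirroring the zero-sized NullTrace."""
    __slots__ = ()

    def store(self, offset: int, ts: int) -> None:
        pass


class GmemTrace:
    """Holds one reference to a backing buffer; store writes ts to buf[offset]."""
    __slots__ = ("buf",)

    def __init__(self, buf: List[int]) -> None:
        self.buf = buf

    def store(self, offset: int, ts: int) -> None:
        self.buf[offset] = ts


def traced_work(trace: TraceBuf, block_idx: int, clock: Callable[[], int]) -> int:
    """Hypothetical block body: the same code runs with either trace type."""
    base = block_idx * 16
    trace.store(base + 0, clock())  # role 0 (hypothetical meaning: start)
    result = sum(range(10))         # stand-in for the block's real work
    trace.store(base + 1, clock())  # role 1 (hypothetical meaning: end)
    return result


ticks = iter(range(100, 200))
buf = [0] * 32
traced_work(GmemTrace(buf), 1, lambda: next(ticks))
print(buf[16], buf[17])  # 100 101
```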
Functions
- gemv_and_partial_norm: Computes y = act @ weight.T, then partitions y into a normed front and an unnormed tail.
- gemv_and_partial_norm_unfused_with_scratch: Unfused 2-launch path with caller-provided y scratch.
- gemv_and_partial_norm_with_scratch: Fused path with caller-provided scratch.
- gemv_partial_norm_kernel: Fused GEMV (M=1) + partial RMS norm, single launch.