
Mojo module

gemv_partial_norm

Fused matvec (M=1) + partial RMS norm on B200.

Given x [1, K], W [N, K] (with transpose_b=True), gamma [N_normed], and eps, this module computes:

y        = x @ W.T                             # [1, N]
normed   = rms_norm(y[:, :N_normed], gamma, eps)
unnormed = y[:, N_normed:]
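The three lines above define the result, not how the kernel computes it. Below is a minimal pure-Python reference model that is handy for checking kernel output. It assumes the conventional RMSNorm definition, `v_i * gamma_i / sqrt(mean(v^2) + eps)`. The names `rms_norm` and `gemv_partial_norm_ref` are hypothetical helpers, not part of this module, and `x` is passed as the single row of the `[1, K]` input.

```python
import math


def rms_norm(v, gamma, eps):
    # Assumed conventional RMSNorm: v_i * gamma_i / sqrt(mean(v^2) + eps)
    ms = sum(t * t for t in v) / len(v)
    inv = 1.0 / math.sqrt(ms + eps)
    return [t * inv * g for t, g in zip(v, gamma)]


def gemv_partial_norm_ref(x, W, gamma, eps):
    # x: the single row of x[1, K]; W: N rows of length K (W[N, K]).
    # y = x @ W.T, so y[n] is the dot product of x with row n of W.
    y = [sum(a * b for a, b in zip(x, row)) for row in W]
    n_normed = len(gamma)                    # gamma has shape [N_normed]
    normed = rms_norm(y[:n_normed], gamma, eps)
    unnormed = y[n_normed:]                  # passed through untouched
    return normed, unnormed
```

For example, with `x = [1.0, 2.0]` and `W = [[1, 0], [0, 1], [1, 1]]`, `y` is `[1, 2, 3]`. With `gamma = [1, 1]`, the first two entries are normalized and `3` is returned as the unnormed tail.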

The fused kernel does all of this in a single launch. Every block that produces normed outputs performs exactly one device-scope acq_rel fetch_add on finish_counter. The release half orders that block's prior normed_output writes before the counter increment. The acquire half matters in the globally last arriver, the block that observes prev == num_normed_blocks - 1: it makes every peer's writes visible, so that block can read normed_output back and perform a single-pass intra-block RMS reduction followed by the gamma scaling.
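The "last block done" handshake can be modeled on the CPU with threads standing in for blocks. Python has no acq_rel atomics, so this sketch emulates `fetch_add` with a lock. Releasing and later acquiring that lock gives the same happens-before edge the kernel gets from the release and acquire halves. It also assumes each block stages its raw slice of `y` in `normed_output` before arriving; that staging detail, and the names `FinishCounter` and `run_fused`, are hypothetical.

```python
import math
import threading


class FinishCounter:
    """Lock-based stand-in for a device-scope acq_rel fetch_add counter."""

    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def fetch_add(self, n):
        # Lock release publishes this thread's prior writes (release half);
        # the next acquirer then sees them (acquire half).
        with self._lock:
            prev = self._value
            self._value += n
            return prev


def run_fused(y_normed_part, gamma, eps, num_normed_blocks):
    n = len(y_normed_part)
    normed_output = [0.0] * n
    counter = FinishCounter()
    chunk = (n + num_normed_blocks - 1) // num_normed_blocks

    def block(bid):
        lo, hi = bid * chunk, min(n, (bid + 1) * chunk)
        # Hypothetical staging: each block writes its raw slice of y.
        for i in range(lo, hi):
            normed_output[i] = y_normed_part[i]
        prev = counter.fetch_add(1)
        if prev == num_normed_blocks - 1:
            # Global last arriver: every peer's writes are visible here,
            # so one pass reduces the RMS and a second applies gamma.
            ms = sum(v * v for v in normed_output) / n
            inv = 1.0 / math.sqrt(ms + eps)
            for i in range(n):
                normed_output[i] *= inv * gamma[i]

    threads = [threading.Thread(target=block, args=(b,)) for b in range(num_normed_blocks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return normed_output
```

Only one thread can observe `prev == num_normed_blocks - 1`, so exactly one block does the reduction and no extra launch or grid-wide barrier is needed.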

The unfused path is a 2-launch baseline using existing primitives (matmul followed by rms_norm_gpu). The matmul writes the full [M, N] output to a caller-provided y_scratch buffer; the RMS norm then reads y_scratch[:, :N_normed] and writes the normed values to normed_output. The unnormed tail lives as a view into y_scratch[:, N_normed:], matching how model code naturally expresses this.
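The two-launch baseline can be mirrored in Python, using `array("d")` as the caller-provided `y_scratch` and `memoryview` slices for zero-copy views. The helpers `matmul_into` and `rms_norm_into` are hypothetical stand-ins for `matmul` and `rms_norm_gpu`, and they assume the conventional RMSNorm formula. The point is the data flow: the full row lands in scratch, only the `[:N_normed]` prefix is normalized, and the tail is a view rather than a copy.

```python
import math
from array import array


def matmul_into(x, W, y_scratch):
    # Stand-in for launch 1 (matmul): write the full [1, N] row of y.
    for n, row in enumerate(W):
        y_scratch[n] = sum(a * b for a, b in zip(x, row))


def rms_norm_into(src, gamma, eps, dst):
    # Stand-in for launch 2 (rms_norm_gpu): read y[:N_normed], write dst.
    ms = sum(v * v for v in src) / len(src)
    inv = 1.0 / math.sqrt(ms + eps)
    for i, (v, g) in enumerate(zip(src, gamma)):
        dst[i] = v * inv * g


def unfused(x, W, gamma, eps):
    n_normed = len(gamma)
    y_scratch = array("d", [0.0] * len(W))       # caller-provided scratch
    normed_output = array("d", [0.0] * n_normed)
    matmul_into(x, W, y_scratch)
    view = memoryview(y_scratch)
    rms_norm_into(view[:n_normed], gamma, eps, normed_output)
    unnormed = view[n_normed:]  # zero-copy view, like y_scratch[:, N_normed:]
    return normed_output, unnormed, y_scratch
```

Because `unnormed` aliases `y_scratch`, the scratch buffer must outlive every use of the tail. The fused kernel avoids this coupling.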

comptime values

GEMV_TRACE_EVENTS_PER_BLOCK

comptime GEMV_TRACE_EVENTS_PER_BLOCK = 16

Number of UInt64 timestamp slots reserved per block in a trace buffer. Only 10 slots are used today (roles 0 through 9); slots 10 through 15 are reserved for future per-iteration instrumentation.
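A plausible way to index such a buffer is one contiguous 16-slot row per block. The page does not state the layout, so the `block_idx * GEMV_TRACE_EVENTS_PER_BLOCK + role` formula and the helper names below are assumptions for illustration. The byte count follows directly from the slot count and the 8-byte `UInt64` timestamp.

```python
GEMV_TRACE_EVENTS_PER_BLOCK = 16  # mirrors the comptime value above
UINT64_BYTES = 8                  # each timestamp slot is a UInt64


def trace_offset(block_idx, role):
    # Hypothetical row-per-block layout; roles 10..15 are reserved.
    if not 0 <= role < GEMV_TRACE_EVENTS_PER_BLOCK:
        raise ValueError("role out of range")
    return block_idx * GEMV_TRACE_EVENTS_PER_BLOCK + role


def trace_bytes(num_blocks):
    # Total buffer size: 16 UInt64 slots per block.
    return num_blocks * GEMV_TRACE_EVENTS_PER_BLOCK * UINT64_BYTES


def alloc_trace(num_blocks):
    # Host-side stand-in for the HBM buffer, zero-initialized.
    return [0] * (num_blocks * GEMV_TRACE_EVENTS_PER_BLOCK)
```

Reserving 16 slots when only 10 are used keeps each block's row a power of two. Adding the future per-iteration roles then changes no offsets.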

Structs

  • GmemTrace: HBM-backed trace buffer; store(offset, ts) writes ts to ptr[offset]. Adds 8 bytes of kernel arguments (the pointer).
  • NullTrace: Zero-sized no-op trace buffer. store is a no-op (pass), and because the struct has no fields it contributes 0 bytes of kernel arguments when passed to a kernel.

Traits

  • TraceBuf: Trace-buffer interface. Implementations: NullTrace, GmemTrace.
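The trait-plus-two-implementations design can be sketched with `typing.Protocol`. The class names match the page, but the bodies and the `kernel_body` function are hypothetical illustrations. One difference matters: in Mojo the trace type is a compile-time parameter, so NullTrace stores can compile away entirely. In this Python sketch they are still ordinary runtime calls that do nothing.

```python
from typing import List, Protocol


class TraceBuf(Protocol):
    def store(self, offset: int, ts: int) -> None: ...


class NullTrace:
    # No fields, mirroring the zero-sized struct; store does nothing.
    __slots__ = ()

    def store(self, offset: int, ts: int) -> None:
        pass


class GmemTrace:
    # A single reference to backing storage, like the one ptr field.
    __slots__ = ("ptr",)

    def __init__(self, ptr: List[int]) -> None:
        self.ptr = ptr

    def store(self, offset: int, ts: int) -> None:
        self.ptr[offset] = ts


def kernel_body(trace: TraceBuf, block_idx: int, clock: int) -> None:
    # Hypothetical kernel code, generic over the trace implementation.
    base = block_idx * 16  # assumes 16 slots per block
    trace.store(base + 0, clock)
    trace.store(base + 1, clock + 5)
```

Keeping tracing behind an interface lets one kernel body serve both production runs (NullTrace) and profiling runs (GmemTrace) without branching on a runtime flag.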

Functions
