
Mojo function

gemv_partial_norm_kernel

gemv_partial_norm_kernel[c_type: DType, a_type: DType, b_type: DType, normed_layout: TensorLayout, unnormed_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, gamma_layout: TensorLayout, TraceBufT: TraceBuf, //, simd_width: Int, tile_n: Int, num_threads: Int, enable_trace: Bool = False, pdl_level: PDLLevel = PDLLevel()](normed_output: TileTensor[c_type, normed_layout, MutAnyOrigin], unnormed_output: TileTensor[c_type, unnormed_layout, MutAnyOrigin], act: TileTensor[a_type, a_layout, MutAnyOrigin], weight: TileTensor[b_type, b_layout, MutAnyOrigin], gamma: TileTensor[a_type, gamma_layout, MutAnyOrigin], finish_counter: UnsafePointer[Int32, MutAnyOrigin], trace_buf: TraceBufT, eps: Scalar[a_type], n: Int, k: Int, n_normed: Int, num_normed_blocks: Int32)

Fused GEMV (M=1) + partial RMS norm, single launch.

Grid layout: (1, ceildiv(n, tile_n)). Each block computes one tile_n-wide column tile of y = act @ weight.T and writes it to either normed_output or unnormed_output; every normed block then performs a single fetch_add on finish_counter. The global-last arriver reads normed_output back, performs a single-pass intra-block RMS reduction, applies gamma in place, and resets the counter.

Constraints:

  • n_normed must be divisible by tile_n.
  • n_normed must be divisible by simd_width (apply-norm uses vectorized loads).
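The fused math can be illustrated with a minimal NumPy reference sketch. This is not the kernel itself, only its semantics; the function name and argument order are illustrative:

```python
import numpy as np

def gemv_partial_norm_ref(act, weight, gamma, eps, n_normed):
    """Reference semantics: y = act @ weight.T, then RMS-norm the first
    n_normed columns with scale gamma; remaining columns pass through."""
    y = act @ weight.T                       # [M, N], here M == 1
    normed = y[:, :n_normed]
    unnormed = y[:, n_normed:]
    # Single-pass RMS reduction over the normed region only.
    rms = np.sqrt(np.mean(normed.astype(np.float64) ** 2,
                          axis=-1, keepdims=True) + eps)
    normed = (normed / rms).astype(y.dtype) * gamma
    return normed, unnormed
```

Note that the reduction runs over the normed region alone, which is why n_normed (not n) appears in the divisibility constraints above.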

Parameters:

  • c_type (DType): Output dtype (of normed_output and unnormed_output).
  • a_type (DType): Activation / gamma dtype.
  • b_type (DType): Weight dtype.
  • normed_layout (TensorLayout): Layout of normed_output.
  • unnormed_layout (TensorLayout): Layout of unnormed_output.
  • a_layout (TensorLayout): Layout of act.
  • b_layout (TensorLayout): Layout of weight.
  • gamma_layout (TensorLayout): Layout of gamma.
  • TraceBufT (TraceBuf): Trace-buffer implementation (NullTrace or GmemTrace). Pass NullTrace for zero-overhead untraced runs.
  • simd_width (Int): Vectorization width for K-loop loads and apply-norm.
  • tile_n (Int): Columns of y computed per block.
  • num_threads (Int): Block size (threads per block).
  • enable_trace (Bool): When True, record per-phase timestamps into trace_buf. When False (default), all record sites compile away to zero PTX.
  • pdl_level (PDLLevel): Programmatic Dependent Launch level for chaining with upstream/downstream kernels.

Args:

  • normed_output (TileTensor): [M, N_normed] output buffer. Normed blocks write their GEMV tiles here; the global-last arriver then rewrites it in place with the RMS norm applied.
  • unnormed_output (TileTensor): [M, N - N_normed] output buffer. Written exactly once by each unnormed block.
  • act (TileTensor): [M, K] activations.
  • weight (TileTensor): [N, K] weights (used as weight.T).
  • gamma (TileTensor): [N_normed] RMS norm scale.
  • finish_counter (UnsafePointer): Single Int32 counter used for the flat grid sync. Must be zero-initialized on first use. The kernel resets it to zero before returning.
  • trace_buf (TraceBufT): Instance of TraceBufT carrying the device-side trace-buffer pointer (or a zero-sized no-op).
  • eps (Scalar): RMS norm epsilon.
  • n (Int): Full output width N (normed + unnormed).
  • k (Int): Activation / weight inner dimension.
  • n_normed (Int): Length of the normed region.
  • num_normed_blocks (Int32): n_normed / tile_n, used by the global-last election (prev == num_normed_blocks - 1).
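The counter-based last-arriver election can be sketched in Python, with a lock standing in for the GPU atomic fetch_add (the class and method names are hypothetical, not part of the kernel's API):

```python
import threading

class FinishCounter:
    """Emulates the single-Int32 flat grid sync: each normed block does one
    fetch_add; the arriver that observes prev == num_normed_blocks - 1 is
    elected global-last and resets the counter before returning."""
    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()  # stands in for the device atomic

    def arrive(self, num_normed_blocks):
        with self._lock:
            prev = self._value         # fetch_add(1) returns the old value
            self._value += 1
            if prev == num_normed_blocks - 1:
                self._value = 0        # reset so the counter is reusable
                return True            # this caller is the global-last arriver
        return False
```

Exactly one arriver sees prev == num_normed_blocks - 1, which is why the counter must be zero-initialized on first use and why the kernel's own reset makes the same buffer safe to reuse on the next launch.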
