
Mojo function

gemv_partial_norm_kernel

gemv_partial_norm_kernel[c_type: DType, a_type: DType, b_type: DType, normed_layout: TensorLayout, unnormed_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, gamma_layout: TensorLayout, TraceBufT: TraceBuf, //, simd_width: Int, tile_n: Int, num_threads: Int, enable_trace: Bool = False, pdl_level: PDLLevel = PDLLevel()](normed_output: TileTensor[c_type, normed_layout, MutAnyOrigin], unnormed_output: TileTensor[c_type, unnormed_layout, MutAnyOrigin], act: TileTensor[a_type, a_layout, MutAnyOrigin], weight: TileTensor[b_type, b_layout, MutAnyOrigin], gamma: TileTensor[a_type, gamma_layout, MutAnyOrigin], finish_counter: UnsafePointer[Int32, MutAnyOrigin], trace_buf: TraceBufT, eps: Scalar[a_type], n: Int, k: Int, n_normed: Int, num_normed_blocks: Int32)

Fused GEMV (M=1) + partial RMS norm, single launch.
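The computation can be written out as equations. This is a sketch under stated assumptions: the first n_normed columns of y are the normed ones, the remaining n - n_normed columns go to unnormed_output, and eps sits inside the square root as in the conventional RMS-norm definition. None of these details are stated explicitly above.

```latex
y_j = \sum_{i=0}^{k-1} \mathrm{act}_i \, \mathrm{weight}_{j,i}, \qquad 0 \le j < n

\mathrm{normed\_output}_j = \frac{y_j}{\sqrt{\frac{1}{n_{\mathrm{normed}}} \sum_{i=0}^{n_{\mathrm{normed}}-1} y_i^2 + \mathrm{eps}}} \, \gamma_j, \qquad 0 \le j < n_{\mathrm{normed}}

\mathrm{unnormed\_output}_{j - n_{\mathrm{normed}}} = y_j, \qquad n_{\mathrm{normed}} \le j < n
```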

Grid layout: (1, ceildiv(n, tile_n)). Each block computes one tile_n-wide column tile of y = act @ weight.T and writes it to either normed_output or unnormed_output; every normed block then performs a single fetch_add on finish_counter. The last normed block to arrive (across the whole grid) reads normed_output back, performs a single-pass intra-block RMS reduction, applies gamma in place, and resets the counter.
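The launch protocol above can be simulated sequentially to make the last-arriver logic concrete. This is a minimal pure-Python sketch, not the kernel itself. It assumes that columns [0, n_normed) are the normed ones, that num_normed_blocks equals n_normed // tile_n, and that eps is added inside the square root. The function name gemv_partial_norm_reference is hypothetical. On a GPU the blocks run concurrently and arrival order is arbitrary; here the loop order decides which block arrives last.

```python
import math


def ceildiv(a, b):
    return -(-a // b)


def gemv_partial_norm_reference(act, weight, gamma, eps, n_normed, tile_n):
    """Hypothetical sequential model of the fused kernel (assumed semantics).

    act: length-k list; weight: n rows of length k, so y = act @ weight.T.
    gamma: length n_normed. Columns [0, n_normed) are assumed to be normed.
    """
    n, k = len(weight), len(act)
    assert n_normed % tile_n == 0
    num_normed_blocks = n_normed // tile_n  # assumed relationship
    normed = [0.0] * n_normed
    unnormed = [0.0] * (n - n_normed)
    finish_counter = 0  # stands in for the global Int32 counter

    for block in range(ceildiv(n, tile_n)):  # grid y-dimension
        start, stop = block * tile_n, min((block + 1) * tile_n, n)
        # GEMV phase: one tile_n-wide column tile of y.
        for col in range(start, stop):
            y = sum(act[i] * weight[col][i] for i in range(k))
            if col < n_normed:
                normed[col] = y
            else:
                unnormed[col - n_normed] = y
        if block < num_normed_blocks:
            prev = finish_counter  # fetch_add returns the old value
            finish_counter += 1
            if prev == num_normed_blocks - 1:  # last normed block to arrive
                # Single-pass RMS reduction over the normed columns.
                mean_sq = sum(v * v for v in normed) / n_normed
                inv_rms = 1.0 / math.sqrt(mean_sq + eps)
                # Apply the norm and gamma in place.
                for j in range(n_normed):
                    normed[j] = normed[j] * inv_rms * gamma[j]
                finish_counter = 0  # reset so the next launch starts clean
    return normed, unnormed, finish_counter
```

Resetting the counter inside the last arriver means the caller does not need a separate memset between launches, which fits the single-launch design described above.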

Constraints:

  • n_normed must be divisible by tile_n.
  • n_normed must be divisible by simd_width (apply-norm uses vectorized loads).
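A host-side check of these constraints, together with the grid shape, might look like the following sketch. The helper name launch_config is hypothetical, and deriving num_normed_blocks as n_normed // tile_n is an assumption based on the constraint that n_normed is divisible by tile_n. Under that constraint, presumably no block's column tile straddles the boundary between normed and unnormed columns.

```python
def launch_config(n, n_normed, tile_n, simd_width):
    """Hypothetical host-side validation; returns (grid, num_normed_blocks)."""
    if n_normed % tile_n != 0:
        raise ValueError("n_normed must be divisible by tile_n")
    if n_normed % simd_width != 0:
        raise ValueError("n_normed must be divisible by simd_width")
    grid = (1, -(-n // tile_n))  # (1, ceildiv(n, tile_n))
    num_normed_blocks = n_normed // tile_n  # assumed relationship
    return grid, num_normed_blocks
```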

Parameters:

  • c_type (DType): Output dtype (of normed_output and unnormed_output).
  • a_type (DType): Activation / gamma dtype.
  • b_type (DType): Weight dtype.
  • normed_layout (TensorLayout): Layout of normed_output.
  • unnormed_layout (TensorLayout): Layout of unnormed_output.
  • a_layout (TensorLayout): Layout of act.
  • b_layout (TensorLayout): Layout of weight.
  • gamma_layout (TensorLayout): Layout of gamma.
  • TraceBufT (TraceBuf): Trace-buffer implementation (NullTrace or GmemTrace). Pass NullTrace for zero-overhead untraced runs.
  • simd_width (Int): Vectorization width for K-loop loads and apply-norm.
  • tile_n (Int): Columns of y computed per block.
  • num_threads (Int): Block size (threads per block).
  • enable_trace (Bool): When True, record per-phase timestamps into trace_buf. When False (default), all record sites compile away and emit no PTX.
  • pdl_level (PDLLevel): Programmatic Dependent Launch level for chaining with upstream/downstream kernels.

Args: