
Mojo function

gemv_partial_norm_kernel

def gemv_partial_norm_kernel[c_type: DType, a_type: DType, b_type: DType, normed_layout: TensorLayout, unnormed_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, gamma_layout: TensorLayout, TraceBufT: TraceBuf, //, simd_width: Int, tile_n: Int, num_threads: Int, enable_trace: Bool = False, pdl_level: PDLLevel = PDLLevel()](normed_output: TileTensor[c_type, normed_layout, MutAnyOrigin], unnormed_output: TileTensor[c_type, unnormed_layout, MutAnyOrigin], act: TileTensor[a_type, a_layout, ImmutAnyOrigin], weight: TileTensor[b_type, b_layout, ImmutAnyOrigin], gamma: TileTensor[a_type, gamma_layout, ImmutAnyOrigin], finish_counter: UnsafePointer[Int32, MutAnyOrigin], trace_buf: TraceBufT, eps: Scalar[a_type], n: Int, k: Int, n_normed: Int, num_normed_blocks: Int32)

Fused GEMV (M=1) + partial RMS norm, single launch.

Grid layout: (1, ceildiv(n, tile_n)). Each block computes one tile_n-wide column tile of y = act @ weight.T, writes to either normed_output or unnormed_output, and every normed block does a single fetch_add on finish_counter. The global-last arriver reads normed_output back, does a single-pass intra-block RMS reduction, applies gamma in place, and resets the counter.
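The flow described above can be modeled on the host. The following is a minimal sequential sketch in Python, not the kernel itself. It rests on several assumptions that the description does not confirm: that the first n_normed columns of y go to normed_output and the remainder to unnormed_output, that num_normed_blocks equals n_normed / tile_n, and that the norm is the common RMSNorm form y * gamma / sqrt(mean(y^2) + eps). The function name gemv_partial_norm_reference is hypothetical, and a plain integer stands in for the atomic finish_counter.

```python
import math


def ceildiv(a, b):
    """Ceiling division for non-negative integers."""
    return -(-a // b)


def gemv_partial_norm_reference(act, weight, gamma, eps, n_normed, tile_n):
    """Sequential model of the fused GEMV + partial RMS norm.

    act:    list of k floats (the M=1 activation row)
    weight: list of n rows, each with k floats, so y = act @ weight.T
    gamma:  list of n_normed floats
    """
    n = len(weight)
    assert n_normed % tile_n == 0, "n_normed must be divisible by tile_n"
    num_normed_blocks = n_normed // tile_n  # assumption, see lead-in

    normed = [0.0] * n_normed
    unnormed = [0.0] * (n - n_normed)
    finish_counter = 0

    # One loop iteration stands in for one block of the (1, ceildiv(n, tile_n)) grid.
    for block in range(ceildiv(n, tile_n)):
        lo = block * tile_n
        hi = min(lo + tile_n, n)
        for col in range(lo, hi):
            y = sum(a * w for a, w in zip(act, weight[col]))
            if col < n_normed:  # assumption: normed columns come first
                normed[col] = y
            else:
                unnormed[col - n_normed] = y

        if block < num_normed_blocks:
            finish_counter += 1  # stands in for the atomic fetch_add
            if finish_counter == num_normed_blocks:
                # Last arriver: reduce, apply gamma in place, reset the counter.
                mean_sq = sum(v * v for v in normed) / n_normed
                scale = 1.0 / math.sqrt(mean_sq + eps)
                for i in range(n_normed):
                    normed[i] = normed[i] * scale * gamma[i]
                finish_counter = 0

    return normed, unnormed, finish_counter
```

On the GPU the blocks finish in arbitrary order, which is why the real kernel detects the last arriver through the counter. In this model the loop order is fixed, so the final normed block is always the one that performs the reduction.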

Constraints:

  • n_normed must be divisible by tile_n.
  • n_normed must be divisible by simd_width (apply-norm uses vectorized loads).
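
A caller could check these constraints on the host before launching. The sketch below is illustrative only: plan_launch is a hypothetical helper, not part of the documented API, and the relation num_normed_blocks = n_normed / tile_n is an assumption inferred from the first constraint rather than something the page states.

```python
def ceildiv(a, b):
    """Ceiling division for non-negative integers."""
    return -(-a // b)


def plan_launch(n, n_normed, tile_n, simd_width):
    """Validate the documented constraints and derive the launch shape."""
    if n_normed % tile_n != 0:
        raise ValueError("n_normed must be divisible by tile_n")
    if n_normed % simd_width != 0:
        raise ValueError("n_normed must be divisible by simd_width")
    return {
        # Grid layout from the description: (1, ceildiv(n, tile_n)).
        "grid": (1, ceildiv(n, tile_n)),
        # Assumption: one normed block per tile_n-wide slice of the normed range.
        "num_normed_blocks": n_normed // tile_n,
    }
```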

Parameters:

  • c_type (DType): Output dtype (of normed_output and unnormed_output).
  • a_type (DType): Activation / gamma dtype.
  • b_type (DType): Weight dtype.
  • normed_layout (TensorLayout): Layout of normed_output.
  • unnormed_layout (TensorLayout): Layout of unnormed_output.
  • a_layout (TensorLayout): Layout of act.
  • b_layout (TensorLayout): Layout of weight.
  • gamma_layout (TensorLayout): Layout of gamma.
  • TraceBufT (TraceBuf): Trace-buffer implementation (NullTrace or GmemTrace). Pass NullTrace for zero-overhead untraced runs.
  • simd_width (Int): Vectorization width for K-loop loads and apply-norm.
  • tile_n (Int): Columns of y computed per block.
  • num_threads (Int): Block size (threads per block).
  • enable_trace (Bool): When True, record per-phase timestamps into trace_buf. When False (default), all record sites compile away to zero PTX.
  • pdl_level (PDLLevel): Programmatic Dependent Launch level for chaining with upstream/downstream kernels.

Args: