Mojo function
gemv_partial_norm_kernel
gemv_partial_norm_kernel[c_type: DType, a_type: DType, b_type: DType, normed_layout: TensorLayout, unnormed_layout: TensorLayout, a_layout: TensorLayout, b_layout: TensorLayout, gamma_layout: TensorLayout, TraceBufT: TraceBuf, //, simd_width: Int, tile_n: Int, num_threads: Int, enable_trace: Bool = False, pdl_level: PDLLevel = PDLLevel()](normed_output: TileTensor[c_type, normed_layout, MutAnyOrigin], unnormed_output: TileTensor[c_type, unnormed_layout, MutAnyOrigin], act: TileTensor[a_type, a_layout, MutAnyOrigin], weight: TileTensor[b_type, b_layout, MutAnyOrigin], gamma: TileTensor[a_type, gamma_layout, MutAnyOrigin], finish_counter: UnsafePointer[Int32, MutAnyOrigin], trace_buf: TraceBufT, eps: Scalar[a_type], n: Int, k: Int, n_normed: Int, num_normed_blocks: Int32)
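Fused GEMV (M=1) + partial RMS norm in a single launch.
Grid layout: (1, ceildiv(n, tile_n)). Each block computes one
tile_n-wide column tile of y = act @ weight.T and writes it to
either normed_output or unnormed_output, depending on which
region the tile covers. Every normed block then does a single
fetch_add on finish_counter; unnormed blocks do not touch the
counter. The global-last arriver (the normed block whose
fetch_add returns num_normed_blocks - 1) reads normed_output
back, computes the RMS with a single-pass intra-block
reduction, scales the values by the RMS factor and gamma in
place, and resets the counter to zero.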
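The flow above can be modeled on the CPU as a correctness oracle. The following Python sketch is a sequential reference model, not the kernel itself. The function name gemv_partial_norm_reference is hypothetical, and blocks run one after another instead of concurrently. A one-element list stands in for finish_counter, and the model assumes the normed region is the first n_normed columns of y.

```python
import math

def gemv_partial_norm_reference(act, weight, gamma, eps, tile_n, n_normed):
    """Sequential CPU model of the fused GEMV + partial RMS norm (M = 1).

    act: K floats; weight: N rows of K values (used as weight.T);
    gamma: n_normed floats. Returns (normed, unnormed) as Python lists.
    Assumption: columns [0, n_normed) form the normed region.
    """
    n = len(weight)
    num_blocks = -(-n // tile_n)               # ceildiv(n, tile_n)
    num_normed_blocks = n_normed // tile_n
    normed = [0.0] * n_normed
    unnormed = [0.0] * (n - n_normed)
    counter = [0]                              # stands in for finish_counter

    def fetch_add(ptr, value):
        # Returns the previous value, like an atomic fetch_add.
        prev = ptr[0]
        ptr[0] += value
        return prev

    for block in range(num_blocks):            # a GPU runs these in any order
        start = block * tile_n
        stop = min(start + tile_n, n)
        for col in range(start, stop):
            y = sum(a * w for a, w in zip(act, weight[col]))
            if col < n_normed:
                normed[col] = y
            else:
                unnormed[col - n_normed] = y
        if start < n_normed:                   # only normed blocks count in
            prev = fetch_add(counter, 1)
            if prev == num_normed_blocks - 1:  # global-last arriver
                mean_sq = sum(v * v for v in normed) / n_normed
                scale = 1.0 / math.sqrt(mean_sq + eps)
                for i in range(n_normed):
                    normed[i] = normed[i] * scale * gamma[i]
                counter[0] = 0                 # reset for the next launch
    return normed, unnormed
```

For example, with act = [1.0, 2.0], weight = [[1, 0], [0, 1], [1, 1], [2, 0]], tile_n = 2 and n_normed = 2, the raw output is y = [1, 2, 3, 2]. The unnormed part comes back unchanged as [3.0, 2.0], while the first two values are divided by sqrt(2.5 + eps) and multiplied by gamma.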
Constraints:
- n_normed must be divisible by tile_n.
- n_normed must be divisible by simd_width (apply-norm uses vectorized loads).
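The constraints can be checked on the host before launch. This is a minimal sketch: check_partial_norm_config is a hypothetical helper, not part of the library. The n_normed <= n check is only implied by the output shapes. Reading the tile_n rule as "a block never straddles the normed/unnormed boundary" is our interpretation of the description, not something the source states.

```python
def check_partial_norm_config(n, n_normed, tile_n, simd_width):
    """Hypothetical host-side validation; returns num_normed_blocks."""
    # Implied by unnormed_output having shape [M, N - N_normed].
    if n_normed > n:
        raise ValueError(f"n_normed ({n_normed}) exceeds n ({n})")
    # Documented: presumably so each block writes only one of the two outputs.
    if n_normed % tile_n != 0:
        raise ValueError(f"n_normed ({n_normed}) must be divisible by tile_n ({tile_n})")
    # Documented: apply-norm uses simd_width-wide vectorized loads.
    if n_normed % simd_width != 0:
        raise ValueError(f"n_normed ({n_normed}) must be divisible by simd_width ({simd_width})")
    return n_normed // tile_n  # value to pass as num_normed_blocks
```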
Parameters:
- c_type (DType): Output dtype of normed_output and unnormed_output.
- a_type (DType): Activation / gamma dtype.
- b_type (DType): Weight dtype.
- normed_layout (TensorLayout): Layout of normed_output.
- unnormed_layout (TensorLayout): Layout of unnormed_output.
- a_layout (TensorLayout): Layout of act.
- b_layout (TensorLayout): Layout of weight.
- gamma_layout (TensorLayout): Layout of gamma.
- TraceBufT (TraceBuf): Trace-buffer implementation (NullTrace or GmemTrace). Pass NullTrace for zero-overhead untraced runs.
- simd_width (Int): Vectorization width for K-loop loads and apply-norm.
- tile_n (Int): Columns of y computed per block.
- num_threads (Int): Block size (threads per block).
- enable_trace (Bool): When True, record per-phase timestamps into trace_buf. When False (default), all record sites compile away to zero PTX.
- pdl_level (PDLLevel): Programmatic Dependent Launch level for chaining with upstream/downstream kernels.
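The way tile_n and num_threads shape the launch can be sketched as follows. launch_geometry is a hypothetical helper. It assumes the block index runs along grid y, matching the (1, ceildiv(n, tile_n)) layout, and that the normed region is columns [0, n_normed). It only classifies tiles and says nothing about how threads split the K loop internally.

```python
def launch_geometry(n, n_normed, tile_n, num_threads):
    """Hypothetical helper: grid/block dims and the column tile of each block.

    Assumption: columns [0, n_normed) are normed, [n_normed, n) are unnormed.
    """
    num_blocks = (n + tile_n - 1) // tile_n    # ceildiv(n, tile_n)
    grid_dim = (1, num_blocks)
    block_dim = num_threads
    tiles = []
    for by in range(num_blocks):               # by: block index along grid y
        start = by * tile_n
        stop = min(start + tile_n, n)          # last tile is partial if tile_n does not divide n
        target = "normed" if start < n_normed else "unnormed"
        tiles.append((by, start, stop, target))
    return grid_dim, block_dim, tiles
```

With n = 10, n_normed = 4, tile_n = 4 and num_threads = 128, this yields a (1, 3) grid. The result is one normed tile covering columns 0 to 3, plus two unnormed tiles, the last covering only columns 8 and 9.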
Args:
- normed_output (TileTensor): [M, N_normed] output buffer. After the RMS reduction, the global-last arriver rewrites it in place with the normalized, gamma-scaled values.
- unnormed_output (TileTensor): [M, N - N_normed] output buffer. Each unnormed block writes its tile exactly once.
- act (TileTensor): [M, K] activations.
- weight (TileTensor): [N, K] weights (used as weight.T).
- gamma (TileTensor): [N_normed] RMS norm scale.
- finish_counter (UnsafePointer): Single Int32 counter used for the flat grid sync. Must be zero-initialized before first use. The kernel resets it to zero before returning, so it is ready for the next launch.
- trace_buf (TraceBufT): Instance of TraceBufT carrying the device-side trace-buffer pointer (or a zero-sized no-op).
- eps (Scalar): RMS norm epsilon.
- n (Int): Full output width N (normed + unnormed).
- k (Int): Activation / weight inner dimension K.
- n_normed (Int): Length of the normed region (N_normed).
- num_normed_blocks (Int32): n_normed / tile_n, used by the global-last election (prev == num_normed_blocks - 1, where prev is the value returned by fetch_add).
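The finish_counter / num_normed_blocks election can be illustrated with host threads. This Python sketch is an analogy, not device code. AtomicInt32 and run_launch are hypothetical names, a threading.Lock stands in for a hardware atomic fetch_add, and the actual normalization is elided. It shows that exactly one arriver sees prev == num_normed_blocks - 1, and that the reset leaves the counter ready for the next launch. A real kernel also needs memory ordering so that the elected block observes the other blocks' writes to normed_output; this model does not capture that.

```python
import threading

class AtomicInt32:
    """Lock-based stand-in for a device-side Int32 with fetch_add."""

    def __init__(self):
        self._value = 0                        # zero-initialized before first use
        self._lock = threading.Lock()

    def fetch_add(self, delta):
        with self._lock:
            prev = self._value
            self._value += delta
            return prev

    def store(self, value):
        with self._lock:
            self._value = value

    def load(self):
        with self._lock:
            return self._value

def run_launch(counter, num_normed_blocks):
    """Simulate one launch; returns the ids of blocks that won the election."""
    winners = []
    winners_lock = threading.Lock()

    def normed_block(block_id):
        prev = counter.fetch_add(1)
        if prev == num_normed_blocks - 1:      # every other normed block has arrived
            with winners_lock:
                winners.append(block_id)
            # (RMS reduction and gamma scaling would happen here.)
            counter.store(0)                   # reset so the next launch starts at zero

    threads = [threading.Thread(target=normed_block, args=(b,))
               for b in range(num_normed_blocks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return winners
```

Calling run_launch repeatedly on the same AtomicInt32 elects exactly one block per call. The counter reads zero afterwards, mirroring the requirement that finish_counter is zeroed once and then maintained by the kernel.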