
Mojo function

gemv_and_partial_norm_with_scratch

gemv_and_partial_norm_with_scratch[c_type: DType, a_type: DType, TraceBufT: TraceBuf = NullTrace, //, *, transpose_b: Bool = True, pdl_level: PDLLevel = PDLLevel(), tile_n: Int = 4, num_threads: Int = 256, enable_trace: Bool = False](normed_output: TileTensor[c_type, normed_output.LayoutType, normed_output.origin, address_space=normed_output.address_space, linear_idx_type=normed_output.linear_idx_type, element_size=normed_output.element_size], unnormed_output: TileTensor[c_type, unnormed_output.LayoutType, unnormed_output.origin, address_space=unnormed_output.address_space, linear_idx_type=unnormed_output.linear_idx_type, element_size=unnormed_output.element_size], act: TileTensor[a_type, act.LayoutType, act.origin, address_space=act.address_space, linear_idx_type=act.linear_idx_type, element_size=act.element_size], weight: TileTensor[a_type, weight.LayoutType, weight.origin, address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size], gamma: TileTensor[a_type, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], eps: Scalar[a_type], finish_counter: UnsafePointer[Int32, MutAnyOrigin], ctx: DeviceContext, trace_buf: TraceBufT = NullTrace())

Fused GEMV and partial RMS norm path that uses caller-provided scratch (the finish_counter buffer).

finish_counter must be zero-initialized on first use. The kernel resets it to zero before returning so the same buffer can be reused across calls without an external memset.
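This page does not document how the kernel uses finish_counter internally. A zero-initialized, self-resetting Int32 counter is commonly used for a "last block done" pattern. In that pattern, each block atomically increments the counter when it finishes. The block that observes the final count does the closing step and writes zero back. The sketch below models that pattern on the host with Python threads. It is an illustration under that assumption, not the kernel's actual code. FinishCounter and launch are hypothetical names.

```python
import threading


class FinishCounter:
    """Hypothetical host-side stand-in for the single-Int32 finish_counter."""

    def __init__(self) -> None:
        self.value = 0  # must be zero before first use
        self._lock = threading.Lock()

    def fetch_add(self, delta: int) -> int:
        # Models an atomic fetch-and-add: returns the value *before* adding.
        with self._lock:
            old = self.value
            self.value += delta
            return old

    def reset(self) -> None:
        with self._lock:
            self.value = 0


def launch(counter: FinishCounter, num_blocks: int) -> list[int]:
    """Simulate one launch; return the id(s) of the block that arrived last."""
    finishers: list[int] = []

    def block(block_id: int) -> None:
        # ... this block's share of the work would run here ...
        # (on a GPU, a memory fence would precede the atomic so the last
        # block sees every other block's writes)
        arrived_before = counter.fetch_add(1)
        if arrived_before == num_blocks - 1:
            # Last arrival: every other block has already incremented, so it
            # can do the final step and reset the counter for the next call.
            finishers.append(block_id)
            counter.reset()

    threads = [threading.Thread(target=block, args=(i,)) for i in range(num_blocks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return finishers


counter = FinishCounter()
for _ in range(3):  # reuse the same counter across calls, no external memset
    assert len(launch(counter, num_blocks=8)) == 1
    assert counter.value == 0
```

Resetting inside the last block is safe because no other block touches the counter after its own increment. That is why the same buffer can be reused across calls.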

Set enable_trace=True and pass a GmemTrace(ptr) as trace_buf to record per-block timestamps into ptr. The buffer must hold num_normed_blocks * GEMV_TRACE_EVENTS_PER_BLOCK u64 values. When tracing is disabled (the default), the trace path is dead-code eliminated, and the generated PTX is byte-identical to that of the untraced kernel.
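The sizing rule above can be turned into a byte count for allocating ptr. This page gives neither the value of GEMV_TRACE_EVENTS_PER_BLOCK nor how num_normed_blocks is derived, so the helper takes both as inputs. The numbers in the usage line are hypothetical.

```python
U64_BYTES = 8  # each recorded event is one u64 timestamp


def trace_buffer_bytes(num_normed_blocks: int, events_per_block: int) -> int:
    """Size in bytes of the buffer behind GmemTrace(ptr).

    events_per_block stands in for GEMV_TRACE_EVENTS_PER_BLOCK; both counts are
    parameters because this page does not give their values.
    """
    if num_normed_blocks < 0 or events_per_block < 0:
        raise ValueError("block and event counts must be non-negative")
    return num_normed_blocks * events_per_block * U64_BYTES


# Hypothetical values, for illustration only.
print(trace_buffer_bytes(num_normed_blocks=32, events_per_block=4))  # 1024
```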

Parameters:

  • c_type (DType): Output dtype.
  • a_type (DType): Activation / weight / gamma dtype.
  • TraceBufT (TraceBuf): Trace-buffer implementation. Defaults to NullTrace for zero-overhead untraced runs.
  • transpose_b (Bool): If True, weight is stored row-major as [N, K] and used as weight.T.
  • pdl_level (PDLLevel): Programmatic Dependent Launch level.
  • tile_n (Int): Comptime tile width in columns.
  • num_threads (Int): Comptime threads per block.
  • enable_trace (Bool): When True, record per-phase timestamps into trace_buf. When False (default), all record sites compile away.
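With transpose_b=True, weight is a row-major [N, K] buffer used as weight.T. Each output column therefore reads one contiguous row of K weights. The sketch below shows that indexing on a flat list for a single activation row. It is a plain reference loop, not the kernel's tiled implementation, and gemv_weight_nk is a hypothetical name.

```python
def gemv_weight_nk(act_row: list[float], weight_flat: list[float], n: int, k: int) -> list[float]:
    """y[j] = sum_i act_row[i] * weight[j][i], i.e. act_row @ weight.T.

    weight is stored row-major as [N, K]: element (j, i) sits at flat offset
    j * k + i, so output column j reads one contiguous run of K values.
    """
    if len(act_row) != k or len(weight_flat) != n * k:
        raise ValueError("shape mismatch")
    y = []
    for j in range(n):
        row = weight_flat[j * k:(j + 1) * k]  # contiguous row j of the [N, K] buffer
        y.append(sum(a * w for a, w in zip(act_row, row)))
    return y


# weight = [[1, 2, 3], [4, 5, 6]] (N=2, K=3), flattened row-major
print(gemv_weight_nk([1.0, 0.0, 2.0], [1, 2, 3, 4, 5, 6], n=2, k=3))  # [7.0, 16.0]
```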

Args:

  • normed_output (TileTensor): [M, N_normed] output buffer. Holds rms_norm(y[:, :N_normed], gamma, eps) on return.
  • unnormed_output (TileTensor): [M, N - N_normed] output buffer. Holds y[:, N_normed:] on return.
  • act (TileTensor): [M, K] activations.
  • weight (TileTensor): [N, K] weights (when transpose_b=True).
  • gamma (TileTensor): [N_normed] RMS norm scale.
  • eps (Scalar): RMS norm epsilon.
  • finish_counter (UnsafePointer): Single-Int32 device counter. Must be zero-initialized on first use; the kernel resets it to zero before returning.
  • ctx (DeviceContext): Device context.
  • trace_buf (TraceBufT): Trace-buffer instance. Defaults to NullTrace().
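The arguments above fully specify what the two outputs should contain. A slow host-side reference is therefore useful for checking results. The sketch below computes y = act @ weight.T for transpose_b=True, then RMS-normalizes the first N_normed columns of each row and passes the remaining columns through. This page does not spell out the exact rms_norm formula. The sketch assumes the common convention x * gamma / sqrt(mean(x^2) + eps) over the N_normed columns of each row. rms_norm_row and reference are hypothetical names.

```python
import math


def rms_norm_row(x: list[float], gamma: list[float], eps: float) -> list[float]:
    # Assumed standard RMSNorm: x * gamma / sqrt(mean(x^2) + eps)
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms * g for v, g in zip(x, gamma)]


def reference(act, weight, gamma, eps):
    """Expected (normed_output, unnormed_output) for transpose_b=True.

    act: [M][K], weight: [N][K], gamma: [N_normed].
    """
    n_normed = len(gamma)
    normed, unnormed = [], []
    for a_row in act:
        # One output row: y = a_row @ weight.T
        y = [sum(a * w for a, w in zip(a_row, w_row)) for w_row in weight]
        normed.append(rms_norm_row(y[:n_normed], gamma, eps))  # first N_normed columns
        unnormed.append(y[n_normed:])  # remaining N - N_normed columns, unnormalized
    return normed, unnormed


normed, unnormed = reference(
    act=[[1.0, 2.0]],                             # M=1, K=2
    weight=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],  # N=3, so y = [1, 2, 3]
    gamma=[1.0, 1.0],                             # N_normed=2
    eps=0.0,
)
print(unnormed)  # [[3.0]]
```

When comparing against device output, use a tolerance rather than exact equality. The kernel runs in a_type and c_type and may accumulate in a different order.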

Raises:

Error: If the kernel launch fails.
