Mojo function
gemv_and_partial_norm_with_scratch
gemv_and_partial_norm_with_scratch[c_type: DType, a_type: DType, TraceBufT: TraceBuf = NullTrace, //, *, transpose_b: Bool = True, pdl_level: PDLLevel = PDLLevel(), tile_n: Int = 4, num_threads: Int = 256, enable_trace: Bool = False](normed_output: TileTensor[c_type, normed_output.LayoutType, normed_output.origin, address_space=normed_output.address_space, linear_idx_type=normed_output.linear_idx_type, element_size=normed_output.element_size], unnormed_output: TileTensor[c_type, unnormed_output.LayoutType, unnormed_output.origin, address_space=unnormed_output.address_space, linear_idx_type=unnormed_output.linear_idx_type, element_size=unnormed_output.element_size], act: TileTensor[a_type, act.LayoutType, act.origin, address_space=act.address_space, linear_idx_type=act.linear_idx_type, element_size=act.element_size], weight: TileTensor[a_type, weight.LayoutType, weight.origin, address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size], gamma: TileTensor[a_type, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], eps: Scalar[a_type], finish_counter: UnsafePointer[Int32, MutAnyOrigin], ctx: DeviceContext, trace_buf: TraceBufT = NullTrace())
Fused GEMV and RMS norm over the first N_normed output columns, using caller-provided scratch (finish_counter).
finish_counter must be zero-initialized on first use. The kernel
resets it to zero before returning so the same buffer can be
reused across calls without an external memset.
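The reuse contract above can be pictured with a host-side Python simulation. It is a sketch under assumptions this page does not state. Each block atomically increments finish_counter when it finishes, and the block that observes the final count is the one that resets the counter to zero. A lock stands in for a device atomic. DeviceCounter, simulate_launch, and num_blocks are hypothetical names, not part of the Mojo API.

```python
import threading


class DeviceCounter:
    """Hypothetical host stand-in for the single-Int32 finish_counter."""

    def __init__(self):
        self.value = 0  # zero-initialized before first use, as required
        self._lock = threading.Lock()

    def fetch_add(self, delta):
        # Emulates a device atomic add: returns the value before the add.
        with self._lock:
            old = self.value
            self.value += delta
            return old

    def store(self, new_value):
        with self._lock:
            self.value = new_value


def simulate_launch(counter, num_blocks):
    """Simulates one launch; returns how many blocks saw themselves as last."""
    last_blocks = []

    def block(block_id):
        # ... per-block GEMV work and partial results would go here ...
        if counter.fetch_add(1) == num_blocks - 1:
            # Assumed: only the last-arriving block finalizes, then it
            # resets the counter so the next launch starts from zero.
            last_blocks.append(block_id)
            counter.store(0)

    threads = [threading.Thread(target=block, args=(i,)) for i in range(num_blocks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return len(last_blocks)


counter = DeviceCounter()
for launch in range(3):  # reuse the same counter, no memset in between
    # Each line: launch index, one last block, counter back at 0.
    print(launch, simulate_launch(counter, num_blocks=8), counter.value)
```

The self-reset matters because it lets the same counter serve the next launch without an extra memset. The catch is that the counter must start at zero: a stale nonzero value would make a block believe it is last too early.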
Set enable_trace=True and pass a GmemTrace(ptr) to record per-block
timestamps into ptr, which must hold num_normed_blocks *
GEMV_TRACE_EVENTS_PER_BLOCK u64 values. When tracing is disabled (the
default), dead-code elimination removes the trace path, so the generated
PTX is byte-identical to that of the untraced kernel.
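The buffer behind ptr holds one 64-bit timestamp (8 bytes) per event, so its size in bytes follows from the formula above. This page does not give the value of GEMV_TRACE_EVENTS_PER_BLOCK or say how num_normed_blocks is derived, so the sketch takes both as arguments. The helper name trace_buffer_bytes is hypothetical.

```python
U64_BYTES = 8  # each trace event is one 64-bit timestamp


def trace_buffer_bytes(num_normed_blocks: int, events_per_block: int) -> int:
    """Bytes needed for ptr: num_normed_blocks * GEMV_TRACE_EVENTS_PER_BLOCK u64s."""
    if num_normed_blocks < 0 or events_per_block < 0:
        raise ValueError("counts must be non-negative")
    return num_normed_blocks * events_per_block * U64_BYTES


# Illustrative values only; substitute the real block and event counts.
print(trace_buffer_bytes(32, 4))  # 1024
```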
Parameters:
- c_type (DType): Output dtype.
- a_type (DType): Activation / weight / gamma dtype.
- TraceBufT (TraceBuf): Trace-buffer implementation. Defaults to NullTrace for zero-overhead untraced runs.
- transpose_b (Bool): If True, weight is row-major [N, K] and is used as weight.T.
- pdl_level (PDLLevel): Programmatic Dependent Launch level.
- tile_n (Int): Comptime tile width in columns.
- num_threads (Int): Comptime threads per block.
- enable_trace (Bool): When True, record per-phase timestamps into trace_buf. When False (default), all record sites compile away.
Args:
- normed_output (TileTensor): [M, N_normed] output buffer. On return it holds rms_norm(y[:, :N_normed], gamma, eps), where y = act @ weight.T is the [M, N] GEMV result.
- unnormed_output (TileTensor): [M, N - N_normed] output buffer. On return it holds y[:, N_normed:].
- act (TileTensor): [M, K] activations.
- weight (TileTensor): [N, K] weights (when transpose_b=True).
- gamma (TileTensor): [N_normed] RMS norm scale.
- eps (Scalar): RMS norm epsilon.
- finish_counter (UnsafePointer): Single-Int32 device counter. Must be zero-initialized on first use; the kernel resets it to zero before returning.
- ctx (DeviceContext): Device context.
- trace_buf (TraceBufT): Trace-buffer instance. Defaults to NullTrace().
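For checking results on the host, the documented outputs can be modeled with a slow Python reference. This is a sketch under stated assumptions, not the kernel's algorithm.

- It covers only transpose_b=True, where weight is [N, K].
- It applies the standard RMSNorm definition, x * gamma / sqrt(mean(x^2) + eps), with the mean taken over the N_normed normed columns of each row. This page names rms_norm but does not spell out its formula.
- It accumulates in Python floats and ignores c_type and a_type, so device results will differ by rounding.

The function name gemv_and_partial_norm_reference is hypothetical.

```python
import math


def gemv_and_partial_norm_reference(act, weight, gamma, eps, n_normed):
    """Hypothetical host-side reference for the documented outputs.

    act:    [M][K] activations
    weight: [N][K] row-major weights, used as weight.T (transpose_b=True)
    gamma:  [n_normed] RMS norm scale
    Returns (normed_output [M][n_normed], unnormed_output [M][N - n_normed]).
    """
    if n_normed <= 0 or n_normed > len(weight) or len(gamma) != n_normed:
        raise ValueError("need 0 < n_normed <= N and len(gamma) == n_normed")

    # y = act @ weight.T, so y[m][n] = sum_k act[m][k] * weight[n][k]
    y = [[sum(a * w for a, w in zip(row, w_row)) for w_row in weight] for row in act]

    normed, unnormed = [], []
    for row in y:
        head, tail = row[:n_normed], row[n_normed:]
        # Assumed standard RMSNorm over the normed columns of this row.
        inv_rms = 1.0 / math.sqrt(sum(v * v for v in head) / n_normed + eps)
        normed.append([v * inv_rms * g for v, g in zip(head, gamma)])
        unnormed.append(tail)  # remaining columns pass through unchanged
    return normed, unnormed


# N = 3, K = 2, N_normed = 2: y = [[1.0, 2.0, 3.0]]
normed, unnormed = gemv_and_partial_norm_reference(
    act=[[1.0, 2.0]],
    weight=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    gamma=[1.0, 1.0],
    eps=0.0,
    n_normed=2,
)
print(unnormed)  # [[3.0]]
```

With gamma set to ones and eps set to zero, each normed row has a mean square of 1. This property makes a quick sanity check against device output.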
Raises:
Error: If the kernel launch fails.