Mojo function
gemv_and_partial_norm
```mojo
gemv_and_partial_norm[c_type: DType, a_type: DType, //, *, transpose_b: Bool = True, fused: Bool = True, tile_n: Int = 4, num_threads: Int = 256, pdl_level: PDLLevel = PDLLevel()](normed_output: TileTensor[c_type, normed_output.LayoutType, normed_output.origin, address_space=normed_output.address_space, linear_idx_type=normed_output.linear_idx_type, element_size=normed_output.element_size], unnormed_output: TileTensor[c_type, unnormed_output.LayoutType, unnormed_output.origin, address_space=unnormed_output.address_space, linear_idx_type=unnormed_output.linear_idx_type, element_size=unnormed_output.element_size], act: TileTensor[a_type, act.LayoutType, act.origin, address_space=act.address_space, linear_idx_type=act.linear_idx_type, element_size=act.element_size], weight: TileTensor[a_type, weight.LayoutType, weight.origin, address_space=weight.address_space, linear_idx_type=weight.linear_idx_type, element_size=weight.element_size], gamma: TileTensor[a_type, gamma.LayoutType, gamma.origin, address_space=gamma.address_space, linear_idx_type=gamma.linear_idx_type, element_size=gamma.element_size], eps: Scalar[a_type], ctx: DeviceContext)
```
Computes y = act @ weight.T, then partitions y into a normed front and an unnormed tail.
Parameters:

- c_type (`DType`): Output dtype.
- a_type (`DType`): Activation / weight / gamma dtype.
- transpose_b (`Bool`): If `True`, `weight` is row-major `[N, K]`, used as `weight.T`.
- fused (`Bool`): Compile-time flag. `True` (default) selects the single-kernel fused path (M=1 only). `False` selects the 2-launch baseline (matmul + `rms_norm_gpu`; the unnormed tail is a view into the matmul output, so `unnormed_output` is left untouched).
- tile_n (`Int`): Compile-time tile width in columns (fused only).
- num_threads (`Int`): Compile-time threads per block (fused only).
- pdl_level (`PDLLevel`): Programmatic Dependent Launch level.
Args:

- normed_output (`TileTensor`): `[M, N_normed]` output buffer. Receives `rms_norm(y[:, :N_normed], gamma, eps)` in both paths.
- unnormed_output (`TileTensor`): `[M, N - N_normed]` output buffer. The fused path writes `y[:, N_normed:]` here; the unfused path leaves this untouched (the unnormed tail is a view into the internally allocated matmul scratch).
- act (`TileTensor`): `[M, K]` activations.
- weight (`TileTensor`): `[N, K]` weights (when `transpose_b=True`).
- gamma (`TileTensor`): `[N_normed]` RMS norm scale.
- eps (`Scalar`): RMS norm epsilon.
- ctx (`DeviceContext`): Device context.
Raises:
Error: If `_matmul_gpu` or `rms_norm_gpu` fails to launch, or if internal scratch allocation fails.
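The computation can be sketched as a NumPy reference, assuming the `transpose_b=True` layout; `gemv_and_partial_norm_ref` is a hypothetical helper illustrating the semantics, not part of the Mojo API:

```python
import numpy as np

def gemv_and_partial_norm_ref(act, weight, gamma, eps):
    """Reference semantics sketch (hypothetical, transpose_b=True layout).

    act:    [M, K] activations
    weight: [N, K] weights, used as weight.T
    gamma:  [N_normed] RMS norm scale; N_normed = gamma.shape[0]
    """
    y = act @ weight.T                      # [M, N] matmul
    n_normed = gamma.shape[0]
    front = y[:, :n_normed]                 # columns to normalize
    # RMS norm over the last axis of the front slice
    rms = np.sqrt(np.mean(front * front, axis=-1, keepdims=True) + eps)
    normed_output = front / rms * gamma     # [M, N_normed]
    unnormed_output = y[:, n_normed:]       # [M, N - N_normed], passed through
    return normed_output, unnormed_output
```

The partition point `N_normed` is implied by `gamma`'s length: the first `N_normed` output columns are RMS-normalized, the remaining `N - N_normed` columns are returned as-is.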