Mojo function
dispatch_sm100_conv2d
dispatch_sm100_conv2d[input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, filter_is_fcrs: Bool, elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None, has_residual: Bool = False](input: LayoutTensor[input_type, input_layout, input.origin, address_space=input.address_space, element_layout=input.element_layout, layout_int_type=input.layout_int_type, linear_idx_type=input.linear_idx_type, masked=input.masked, alignment=input.alignment], filter: LayoutTensor[filter_type, filter_layout, filter.origin, address_space=filter.address_space, element_layout=filter.element_layout, layout_int_type=filter.layout_int_type, linear_idx_type=filter.linear_idx_type, masked=filter.masked, alignment=filter.alignment], output: LayoutTensor[output_type, output_layout, output.origin, address_space=output.address_space, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], symmetric_padding: IndexList[2], ctx: DeviceContext, source_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin] = UnsafePointer(), beta: Float32 = 0)
Dispatch to SM100 structured conv2d with filter transpose.
This function gates the SM100 kernel import behind @parameter if on dtype, so the kernel is never compiled for unsupported dtypes.
Parameters:
- input_layout (Layout): Layout of the input activation tensor.
- filter_layout (Layout): Layout of the filter weights tensor.
- output_layout (Layout): Layout of the output tensor.
- input_type (DType): Data type of the input activation tensor.
- filter_type (DType): Data type of the filter weights tensor.
- output_type (DType): Data type of the output tensor.
- filter_is_fcrs (Bool): If True, the filter is in FCRS layout; otherwise it is in RSCF layout.
- elementwise_lambda_fn (Optional): Optional void epilogue lambda applied after the output write. Signature: fn(IndexList[2], SIMD) -> None.
- has_residual (Bool): If True, fuse the residual add D = Conv(A,B) + beta*C.
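To make the two filter layouts named by filter_is_fcrs concrete, here is a minimal pure-Python sketch of the index mapping between them. It assumes the conventional reading of the letters (F = output channels, C = input channels, R = filter height, S = filter width), which this page does not spell out. The helper name fcrs_to_rscf is hypothetical and is not part of the Mojo API; the sketch only shows the reordering that a filter transpose of this kind would perform, not how the kernel does it.

```python
def fcrs_to_rscf(filt_fcrs):
    """Reorder a nested-list filter from [F][C][R][S] to [R][S][C][F].

    Assumed axis meanings (not stated on this page):
    F = output channels, C = input channels, R = filter height, S = filter width.
    """
    F = len(filt_fcrs)
    C = len(filt_fcrs[0])
    R = len(filt_fcrs[0][0])
    S = len(filt_fcrs[0][0][0])
    # result[r][s][c][f] == filt_fcrs[f][c][r][s]
    return [[[[filt_fcrs[f][c][r][s] for f in range(F)]
              for c in range(C)]
             for s in range(S)]
            for r in range(R)]


# Tiny filter with F=2, C=1, R=1, S=3.
fcrs = [[[[1, 2, 3]]], [[[4, 5, 6]]]]
rscf = fcrs_to_rscf(fcrs)
```

In the result, the innermost axis walks over output channels, so the two taps at the same spatial position (for example 1 and 4) end up adjacent.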
Args:
- input (LayoutTensor): Input activation tensor in NHWC layout.
- filter (LayoutTensor): Filter weights tensor.
- output (LayoutTensor): Output tensor in NHWC layout.
- symmetric_padding (IndexList): Symmetric padding (pad_h, pad_w).
- ctx (DeviceContext): Device context for kernel launch.
- source_ptr (UnsafePointer): Pointer to the residual source tensor C (NHWC, same shape as output). Only used when has_residual is True.
- beta (Float32): Residual scale factor in D = Conv(A,B) + beta*C.
Raises:
Error if kernel launch fails.
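The semantics described above (NHWC input, symmetric padding, and the optional residual D = Conv(A,B) + beta*C) can be sketched as a naive pure-Python reference. This is an illustration under stated assumptions, not the SM100 kernel: it assumes stride 1 and dilation 1 (the signature exposes neither), a filter already in RSCF order, and cross-correlation (no filter flip), as is usual for deep-learning convolutions. The function name conv2d_nhwc_reference and its keyword arguments are hypothetical.

```python
def conv2d_nhwc_reference(inp, filt_rscf, pad, source=None, beta=0.0):
    """Naive conv2d on nested lists.

    inp:       [N][H][W][C]  input activations (NHWC)
    filt_rscf: [R][S][C][F]  filter weights (RSCF, assumed)
    pad:       (pad_h, pad_w) symmetric zero padding
    source:    optional [N][P][Q][F] residual tensor C
    Assumes stride 1 and dilation 1. Returns [N][P][Q][F].
    """
    pad_h, pad_w = pad
    N, H, W, C = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    R, S, F = len(filt_rscf), len(filt_rscf[0]), len(filt_rscf[0][0][0])
    P = H + 2 * pad_h - R + 1  # output height
    Q = W + 2 * pad_w - S + 1  # output width
    out = [[[[0.0] * F for _ in range(Q)] for _ in range(P)] for _ in range(N)]
    for n in range(N):
        for p in range(P):
            for q in range(Q):
                for f in range(F):
                    acc = 0.0
                    for r in range(R):
                        for s in range(S):
                            h = p + r - pad_h
                            w = q + s - pad_w
                            # Positions outside the input act as zero padding.
                            if 0 <= h < H and 0 <= w < W:
                                for c in range(C):
                                    acc += inp[n][h][w][c] * filt_rscf[r][s][c][f]
                    if source is not None:
                        # Fused residual: D = Conv(A, B) + beta * C
                        acc += beta * source[n][p][q][f]
                    out[n][p][q][f] = acc
    return out


# 1x2x2x1 input and a 3x3 all-ones filter with one input and one output channel.
inp = [[[[1.0], [2.0]], [[3.0], [4.0]]]]
ones3 = [[[[1.0]] for _ in range(3)] for _ in range(3)]

conv_only = conv2d_nhwc_reference(inp, ones3, (1, 1))
fused = conv2d_nhwc_reference(inp, ones3, (1, 1), source=inp, beta=0.5)
```

With padding (1, 1) the 3x3 window always covers the whole 2x2 input, so every element of conv_only is the sum of the four inputs; fused then adds half of the matching residual element on top.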