
Mojo function

dispatch_sm100_conv2d

dispatch_sm100_conv2d[
    input_type: DType,
    filter_type: DType,
    output_type: DType,
    filter_is_fcrs: Bool,
    elementwise_lambda_fn: Optional[elementwise_epilogue_type] = None,
    has_residual: Bool = False
](
    input: TileTensor[input_type, input.LayoutType, input.origin, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size],
    filter: TileTensor[filter_type, filter.LayoutType, filter.origin, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size],
    output: TileTensor[output_type, output.LayoutType, output.origin, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size],
    symmetric_padding: IndexList[2],
    ctx: DeviceContext,
    source_ptr: UnsafePointer[Scalar[output_type], MutAnyOrigin] = UnsafePointer(),
    beta: Float32 = 0
)

Dispatch to SM100 structured conv2d with filter transpose.

This function gates the SM100 kernel import behind a compile-time @parameter if check on the dtypes, so the kernel is never compiled for unsupported dtypes.

Parameters:

  • input_type (DType): Data type of the input activation tensor.
  • filter_type (DType): Data type of the filter weights tensor.
  • output_type (DType): Data type of the output tensor.
  • filter_is_fcrs (Bool): If True, filter is FCRS layout; otherwise RSCF.
  • elementwise_lambda_fn (Optional): Optional void epilogue lambda applied after output write. Signature: def(IndexList[2], SIMD) -> None.
  • has_residual (Bool): If True, fuse the residual add D = Conv(A,B) + beta*C, where A is the input, B the filter, C the residual source, and D the output.
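
The two filter layouts named by filter_is_fcrs can be illustrated with a small host-side sketch. It is written in plain Python purely for illustration and is not part of the Mojo API: the helper name fcrs_to_rscf is hypothetical, and the reading of the letters (F = output filters, C = input channels, R = filter height, S = filter width) is the usual convention, assumed here rather than stated on this page.

```python
def fcrs_to_rscf(filt_fcrs):
    """Permute a nested-list filter from [F][C][R][S] to [R][S][C][F].

    Hypothetical helper for illustration only. Assumes F = output filters,
    C = input channels, R = filter height, S = filter width.
    """
    f = len(filt_fcrs)
    c = len(filt_fcrs[0])
    r = len(filt_fcrs[0][0])
    s = len(filt_fcrs[0][0][0])
    # Element (k, ch, i, j) in FCRS becomes element (i, j, ch, k) in RSCF.
    return [
        [
            [[filt_fcrs[k][ch][i][j] for k in range(f)] for ch in range(c)]
            for j in range(s)
        ]
        for i in range(r)
    ]


def make_fcrs(f, c, r, s):
    """Build a test filter whose values encode their own FCRS coordinates."""
    return [
        [
            [[k * 1000 + ch * 100 + i * 10 + j for j in range(s)] for i in range(r)]
            for ch in range(c)
        ]
        for k in range(f)
    ]
```

Only the index order changes; the values themselves are untouched, so the same weights describe the same convolution in either layout.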

Args:

  • input (TileTensor): Input activation tensor in NHWC layout.
  • filter (TileTensor): Filter weights tensor.
  • output (TileTensor): Output tensor in NHWC layout.
  • symmetric_padding (IndexList): Symmetric padding (pad_h, pad_w).
  • ctx (DeviceContext): Device context for kernel launch.
  • source_ptr (UnsafePointer): Pointer to residual source tensor C (NHWC, same shape as output). Only used when has_residual is True.
  • beta (Float32): Residual scale factor. D = Conv(A,B) + beta*C.
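
As a rough guide to what the arguments above describe, the following plain-Python sketch computes D = Conv(A,B) + beta*C on the host for an NHWC input and an RSCF filter. It is a reference for the arithmetic only, not the SM100 kernel or its API. The function name and argument names are hypothetical, and since the signature exposes no stride or dilation arguments, the sketch assumes stride 1 and dilation 1 and uses cross-correlation (no filter flip), as is common in deep-learning libraries.

```python
def conv2d_nhwc_reference(inp, filt_rscf, pad_h, pad_w, residual=None, beta=0.0):
    """Direct convolution on nested lists.

    inp:       [N][H][W][C] input (A)
    filt_rscf: [R][S][C][F] filter (B)
    residual:  [N][out_H][out_W][F] residual source (C), or None
    Assumptions (not from the docs): stride 1, dilation 1, zero padding.
    """
    n_batch, h, w, c = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    r, s, f = len(filt_rscf), len(filt_rscf[0]), len(filt_rscf[0][0][0])
    out_h = h + 2 * pad_h - r + 1
    out_w = w + 2 * pad_w - s + 1
    out = [
        [[[0.0] * f for _ in range(out_w)] for _ in range(out_h)]
        for _ in range(n_batch)
    ]
    for n in range(n_batch):
        for oh in range(out_h):
            for ow in range(out_w):
                for k in range(f):
                    acc = 0.0
                    for i in range(r):
                        ih = oh + i - pad_h
                        if ih < 0 or ih >= h:
                            continue  # row falls in the zero padding
                        for j in range(s):
                            iw = ow + j - pad_w
                            if iw < 0 or iw >= w:
                                continue  # column falls in the zero padding
                            for ch in range(c):
                                acc += inp[n][ih][iw][ch] * filt_rscf[i][j][ch][k]
                    if residual is not None:
                        # Fused residual add: D = Conv(A,B) + beta*C
                        acc += beta * residual[n][oh][ow][k]
                    out[n][oh][ow][k] = acc
    return out
```

With symmetric padding (1, 1) and a 3x3 filter, the output keeps the input's spatial size under these assumptions, which is why the residual tensor is indexed with the output coordinates.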

Raises:

Error if kernel launch fails.
