Mojo function
dispatch_sm100_conv2d
dispatch_sm100_conv2d[input_type: DType, filter_type: DType, output_type: DType, //, filter_is_fcrs: Bool = False, elementwise_lambda_fn: Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, has_residual: Bool = False](input: TileTensor[input_type, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size], filter: TileTensor[filter_type, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size], output: TileTensor[output_type, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], symmetric_padding: IndexList[2], ctx: DeviceContext, source_ptr: OptionalReg[UnsafePointer[Scalar[output_type], MutAnyOrigin]] = None, beta: Float32 = 0)
Dispatch to SM100 structured conv2d with filter transpose.
This function gates the SM100 kernel import behind @parameter if on dtype, so the kernel is never compiled for unsupported dtypes.
Parameters:
- input_type (DType): Data type of the input activation tensor.
- filter_type (DType): Data type of the filter weights tensor.
- output_type (DType): Data type of the output tensor.
- filter_is_fcrs (Bool): If True, filter is FCRS layout; otherwise RSCF.
- elementwise_lambda_fn (Optional[def[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None]): Optional void epilogue lambda applied after output write. Signature: def(IndexList[2], SIMD) -> None.
- has_residual (Bool): If True, fuse residual add D = Conv(A,B) + beta*C.
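The filter_is_fcrs flag selects between two filter layouts. As an illustration only, the sketch below permutes a filter from FCRS to RSCF in plain Python, assuming the usual reading of the letters (F = output channels, C = input channels, R = filter height, S = filter width). The helper name fcrs_to_rscf is hypothetical and is not part of the Mojo API; the actual kernel performs its transpose on the device.

```python
def fcrs_to_rscf(filter_fcrs):
    """Permute a nested-list filter from (F, C, R, S) to (R, S, C, F).

    Hypothetical helper for illustration; assumes F = output channels,
    C = input channels, R = filter height, S = filter width.
    """
    F = len(filter_fcrs)
    C = len(filter_fcrs[0])
    R = len(filter_fcrs[0][0])
    S = len(filter_fcrs[0][0][0])
    return [
        [
            [[filter_fcrs[f][c][r][s] for f in range(F)] for c in range(C)]
            for s in range(S)
        ]
        for r in range(R)
    ]


# Tiny filter: F=2 output channels, C=1 input channel, 1x2 spatial window.
example_fcrs = [[[[1.0, 2.0]]], [[[3.0, 4.0]]]]
example_rscf = fcrs_to_rscf(example_fcrs)
```

In the RSCF result the output-channel index varies fastest, so the two filters' values at the same spatial position sit next to each other.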
Args:
- input (TileTensor[input_type, address_space=input.address_space, linear_idx_type=input.linear_idx_type, element_size=input.element_size]): Input activation tensor in NHWC layout.
- filter (TileTensor[filter_type, address_space=filter.address_space, linear_idx_type=filter.linear_idx_type, element_size=filter.element_size]): Filter weights tensor.
- output (TileTensor[output_type, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Output tensor in NHWC layout.
- symmetric_padding (IndexList[2]): Symmetric padding (pad_h, pad_w).
- ctx (DeviceContext): Device context for kernel launch.
- source_ptr (OptionalReg[UnsafePointer[Scalar[output_type], MutAnyOrigin]]): Pointer to residual source tensor C (NHWC, same shape as output). Only used when has_residual is True.
- beta (Float32): Residual scale factor. D = Conv(A,B) + beta*C.
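To make the fused formula D = Conv(A,B) + beta*C concrete, here is a slow pure-Python reference model on nested lists. It is a sketch under stated assumptions, not the kernel's implementation: the signature above exposes no stride or dilation, so the model assumes stride 1 and dilation 1, an RSCF filter, and the cross-correlation convention (no filter flip). The function name conv2d_nhwc_reference is hypothetical.

```python
def conv2d_nhwc_reference(inp, filt_rscf, pad, source=None, beta=0.0):
    """Reference D = Conv(A, B) + beta * C on nested lists.

    inp: [N][H][W][C], filt_rscf: [R][S][C][F], pad: (pad_h, pad_w).
    Assumptions (not confirmed by the docs): stride 1, dilation 1,
    cross-correlation without filter flip.
    """
    pad_h, pad_w = pad
    N, H, W, C = len(inp), len(inp[0]), len(inp[0][0]), len(inp[0][0][0])
    R, S, F = len(filt_rscf), len(filt_rscf[0]), len(filt_rscf[0][0][0])
    H_out = H + 2 * pad_h - R + 1
    W_out = W + 2 * pad_w - S + 1
    out = [
        [[[0.0] * F for _ in range(W_out)] for _ in range(H_out)]
        for _ in range(N)
    ]
    for n in range(N):
        for ho in range(H_out):
            for wo in range(W_out):
                for f in range(F):
                    acc = 0.0
                    for r in range(R):
                        for s in range(S):
                            hi, wi = ho + r - pad_h, wo + s - pad_w
                            # Positions outside the input act as zero padding.
                            if 0 <= hi < H and 0 <= wi < W:
                                for c in range(C):
                                    acc += inp[n][hi][wi][c] * filt_rscf[r][s][c][f]
                    # Residual term, mirroring has_residual / source_ptr / beta.
                    if source is not None:
                        acc += beta * source[n][ho][wo][f]
                    out[n][ho][wo][f] = acc
    return out


# 1x2x2x1 input, 1x1 filter with weight 2, residual of ones scaled by 0.5.
a = [[[[1.0], [2.0]], [[3.0], [4.0]]]]
b = [[[[2.0]]]]
c = [[[[1.0], [1.0]], [[1.0], [1.0]]]]
d_plain = conv2d_nhwc_reference(a, b, (0, 0))
d_fused = conv2d_nhwc_reference(a, b, (0, 0), source=c, beta=0.5)
```

A model like this can serve as a correctness check for small shapes: passing source=None corresponds to has_residual being False, in which case beta has no effect.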
Raises:
Error if kernel launch fails.