
Mojo struct

Conv2dProblemShape

struct Conv2dProblemShape

Defines 2D convolution problem geometry.

Layouts:

  • Activation: NHWC (batch, height, width, channels)
  • Filter: KRSC (output_channels, filter_h, filter_w, input_channels)
  • Output: NHWC (batch, out_height, out_width, output_channels)
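The layouts above fix the memory order of each tensor. As a rough illustration, the following Python sketch computes the linear element offset in a densely packed NHWC activation and a packed KRSC filter. The helpers `nhwc_offset` and `krsc_offset` are hypothetical, and the sketch assumes contiguous storage with no row padding or alignment gaps.

```python
def nhwc_offset(n, h, w, c, H, W, C):
    """Linear offset of element (n, h, w, c) in a packed NHWC tensor.

    Channels are innermost, so consecutive c values are adjacent in memory.
    """
    return ((n * H + h) * W + w) * C + c


def krsc_offset(k, r, s, c, R, S, C):
    """Linear offset of element (k, r, s, c) in a packed KRSC filter.

    K = output channels, R = filter height, S = filter width, C = input channels.
    """
    return ((k * R + r) * S + s) * C + c


# A 1x4x4x8 activation: element (0, 1, 2, 3) sits at ((0*4+1)*4+2)*8+3.
print(nhwc_offset(0, 1, 2, 3, H=4, W=4, C=8))  # 51
```

Keeping C innermost in both tensors means the reduction over input channels reads contiguous memory in the activation and the filter alike.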

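For forward propagation (Fprop) with stride 1 and no dilation, the problem maps to an implicit GEMM as follows, where N is the batch size, K the number of output channels, C the number of input channels, and R * S the filter area:

  • GEMM M = N * H_out * W_out (batch * output height * output width)
  • GEMM N = K (output channels)
  • GEMM K = C * R * S (input channels * filter area)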
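To make the mapping concrete, the Python sketch below splits a GEMM row index into (n, p, q) output coordinates and a GEMM reduction index into (r, s, c) filter coordinates, then finds the activation element that GEMM operand A[m, k] reads. It assumes stride 1, no dilation, per-side zero padding, and a reduction order with C innermost (matching KRSC). P and Q stand for H_out and W_out, and all function names are hypothetical; this is not a description of the library's kernels.

```python
def decompose_m(m, P, Q):
    """Split GEMM row index m into (n, p, q): batch, output row, output column."""
    n, rem = divmod(m, P * Q)
    p, q = divmod(rem, Q)
    return n, p, q


def decompose_k(k, S, C):
    """Split GEMM reduction index k into (r, s, c), with c innermost (KRSC order)."""
    r, rem = divmod(k, S * C)
    s, c = divmod(rem, C)
    return r, s, c


def activation_coord(m, k, H, W, P, Q, S, C, pad_h, pad_w):
    """Activation coordinate (n, h, w, c) read for GEMM element A[m, k].

    Assumes stride 1 and dilation 1. Returns None when (h, w) lands in the
    zero-padding border, i.e. the element contributes zero to the dot product.
    """
    n, p, q = decompose_m(m, P, Q)
    r, s, c = decompose_k(k, S, C)
    h = p - pad_h + r  # stride 1, dilation 1
    w = q - pad_w + s
    if not (0 <= h < H and 0 <= w < W):
        return None
    return n, h, w, c


# 4x4 input, 3x3 filter, pad 1 (so P = Q = 4), C = 2.
# k = 8 decodes to filter tap (r=1, s=1, c=0), the filter's center.
print(activation_coord(0, 8, H=4, W=4, P=4, Q=4, S=3, C=2, pad_h=1, pad_w=1))  # (0, 0, 0, 0)
```

No im2col matrix is materialized here: each (m, k) pair is decoded on the fly, which is the essence of an implicit GEMM.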

Fields

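  • batch (Int): Number of images in the batch (N).
  • in_height (Int): Input activation height (H).
  • in_width (Int): Input activation width (W).
  • in_channels (Int): Number of input channels (C).
  • out_channels (Int): Number of output channels, i.e. filters (K).
  • filter_h (Int): Filter height (R).
  • filter_w (Int): Filter width (S).
  • pad_h (Int): Zero padding along the height dimension.
  • pad_w (Int): Zero padding along the width dimension.
  • stride_h (Int): Stride along the height dimension.
  • stride_w (Int): Stride along the width dimension.
  • dilation_h (Int): Filter dilation along the height dimension.
  • dilation_w (Int): Filter dilation along the width dimension.
  • groups (Int): Number of convolution groups (1 for a standard convolution).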

Implemented traits

AnyType, Copyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

Methods

__init__

__init__(out self, batch: Int, in_height: Int, in_width: Int, in_channels: Int, out_channels: Int, filter_h: Int, filter_w: Int, pad_h: Int = 0, pad_w: Int = 0, stride_h: Int = 1, stride_w: Int = 1, dilation_h: Int = 1, dilation_w: Int = 1, groups: Int = 1)
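The constructor requires the seven geometry arguments and defaults the rest to an unpadded, unit-stride, undilated, ungrouped convolution. The Python dataclass below is a hypothetical mirror of those fields and defaults, handy for reasoning about shapes outside Mojo. It is not part of the Mojo API, and the per-side meaning of pad_h/pad_w in the example is an assumption.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Conv2dProblemShapeSketch:
    """Hypothetical Python mirror of the Mojo Conv2dProblemShape fields."""
    batch: int
    in_height: int
    in_width: int
    in_channels: int
    out_channels: int
    filter_h: int
    filter_w: int
    pad_h: int = 0        # defaults match the Mojo __init__ signature
    pad_w: int = 0
    stride_h: int = 1
    stride_w: int = 1
    dilation_h: int = 1
    dilation_w: int = 1
    groups: int = 1


# A 3x3 convolution over 8 images of 56x56x64 producing 128 channels,
# assuming pad_h/pad_w mean 1 pixel of zero padding on each side.
shape = Conv2dProblemShapeSketch(8, 56, 56, 64, 128, 3, 3, pad_h=1, pad_w=1)
print(shape.stride_h, shape.groups)  # 1 1
```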

out_height

out_height(self) -> Int

Computes the output height from in_height, filter_h, pad_h, stride_h, and dilation_h.

Returns:

Int

out_width

out_width(self) -> Int

Computes the output width from in_width, filter_w, pad_w, stride_w, and dilation_w.

Returns:

Int
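This page does not spell out the output-size formula. The Python sketch below uses the definition common to most deep learning frameworks, assuming `pad` zeros on each side of the axis and floor division; the Mojo implementation may differ in edge cases, so treat it as an approximation. The helper `conv_out_dim` is hypothetical and applies to either axis (height with filter_h, pad_h, stride_h, dilation_h; width with the `_w` counterparts).

```python
def conv_out_dim(in_dim, filt, pad, stride, dilation):
    """Output size along one spatial axis (standard convolution formula).

    Assumes `pad` zeros on each side of the axis.
    """
    effective_filter = dilation * (filt - 1) + 1  # span covered by a dilated filter
    return (in_dim + 2 * pad - effective_filter) // stride + 1


# 56x56 input, 3x3 filter, pad 1, stride 1: output keeps the input size.
print(conv_out_dim(56, 3, 1, 1, 1))   # 56
# 224x224 input, 7x7 filter, pad 3, stride 2: output is halved.
print(conv_out_dim(224, 7, 3, 2, 1))  # 112
```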

gemm_m

gemm_m(self) -> Int

GEMM M dimension = batch * out_height() * out_width().

Returns:

Int

gemm_n

gemm_n(self) -> Int

GEMM N dimension = out_channels.

Returns:

Int

gemm_k

gemm_k(self) -> Int

GEMM K dimension = in_channels * filter_h * filter_w.

Returns:

Int
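Taken together, the three GEMM dimensions can be computed as in the Python sketch below. It assumes groups == 1 (this page does not say how gemm_k treats grouped convolution) and reuses the standard output-size formula under the same per-side padding assumption. The function `gemm_dims` is hypothetical.

```python
def conv_out_dim(in_dim, filt, pad, stride, dilation):
    # Standard output-size formula; assumes `pad` zeros on each side.
    return (in_dim + 2 * pad - dilation * (filt - 1) - 1) // stride + 1


def gemm_dims(batch, in_h, in_w, in_c, out_c, filt_h, filt_w,
              pad_h=0, pad_w=0, stride_h=1, stride_w=1, dil_h=1, dil_w=1):
    """Return (M, N, K) of the implicit GEMM for an ungrouped Fprop convolution."""
    out_h = conv_out_dim(in_h, filt_h, pad_h, stride_h, dil_h)
    out_w = conv_out_dim(in_w, filt_w, pad_w, stride_w, dil_w)
    m = batch * out_h * out_w   # one GEMM row per output pixel
    n = out_c                   # one GEMM column per output channel
    k = in_c * filt_h * filt_w  # reduction over C * R * S
    return m, n, k


# 8 images of 56x56x64, 128 output channels, 3x3 filter, pad 1.
print(gemm_dims(8, 56, 56, 64, 128, 3, 3, pad_h=1, pad_w=1))  # (25088, 128, 576)
```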

num_m_tiles

num_m_tiles(self, tile_m: Int) -> Int

Number of tiles along the GEMM M dimension (gemm_m()) for a tile size of tile_m.

Returns:

Int

num_n_tiles

num_n_tiles(self, tile_n: Int) -> Int

Number of tiles along the GEMM N dimension (gemm_n()) for a tile size of tile_n.

Returns:

Int

num_k_tiles

num_k_tiles(self, tile_k: Int) -> Int

Number of tiles along the GEMM K dimension (gemm_k()) for a tile size of tile_k.

Returns:

Int
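Tile counts like these are usually obtained by ceiling division, so a partial tile at the edge of a dimension still counts as one tile. This page does not state the rounding, so the Python sketch below assumes it. The helpers `ceil_div` and `num_tiles` are hypothetical.

```python
def ceil_div(a, b):
    """Smallest integer >= a / b for a >= 0 and b > 0, using integer arithmetic only."""
    return (a + b - 1) // b


def num_tiles(gemm_m, gemm_n, gemm_k, tile_m, tile_n, tile_k):
    """Tiles needed to cover each GEMM dimension; edge tiles may be partial."""
    return ceil_div(gemm_m, tile_m), ceil_div(gemm_n, tile_n), ceil_div(gemm_k, tile_k)


# M=25088, N=128, K=576 with 128x128x64 tiles divides evenly.
print(num_tiles(25088, 128, 576, 128, 128, 64))  # (196, 1, 9)
# K=100 with tile_k=64 needs a second, partial tile.
print(ceil_div(100, 64))  # 2
```

In a typical tiled GEMM, the M and N tile counts set the size of the output tile grid, while the K tile count sets how many main-loop iterations each output tile performs.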
