
Mojo struct

Conv2dSmem

struct Conv2dSmem[act_type: DType, filter_type: DType, out_type: DType, *, config: Conv2dConfig[act_type, filter_type, out_type]]

Shared memory layout for the SM100 Conv2D forward-propagation (fprop) kernel.

This struct manages shared memory allocation for:

  • Activation tiles (after im2col transformation)
  • Filter tiles
  • Output tiles (accumulated results staged for the epilogue)
  • Source C tiles for residual operations
  • Synchronization barriers

The layout mirrors B200MatmulSmem but with conv-specific semantics:

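  • A tiles = im2col'd activation (M x K, where M = N·H_out·W_out is the number of output pixels and K = C·R·S is input channels times filter height times filter width)
  • B tiles = filter (the K x N operand, where K = C·R·S and N = K_out is the number of output channels; stored K-major, so each tile is BN x BK)
  • C tiles = output (M x N)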
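The operand mapping above can be checked numerically on the host. The following Python sketch is not the kernel's code: it compares a direct convolution against the same computation written as an M x K by K x N GEMM whose A elements are gathered on the fly (implicit im2col). It assumes NHWC activations, a KRSC filter, no dilation, and a GEMM K index ordered as (r, s, c) with c fastest; the kernel's actual memory layouts and K ordering may differ, and every function name here is hypothetical.

```python
import random

def out_size(extent, filt, stride, pad):
    # Output extent along one spatial dimension (no dilation).
    return (extent + 2 * pad - filt) // stride + 1

def conv2d_direct(x, w, dims):
    """Reference NHWC conv: x[n][h][w][c], w[k][r][s][c] (KRSC filter)."""
    N, H, W, C, K_out, R, S, stride, pad = dims
    P, Q = out_size(H, R, stride, pad), out_size(W, S, stride, pad)
    y = [[[[0] * K_out for _ in range(Q)] for _ in range(P)] for _ in range(N)]
    for n in range(N):
        for p in range(P):
            for q in range(Q):
                for k in range(K_out):
                    acc = 0
                    for r in range(R):
                        for s in range(S):
                            h, wi = p * stride - pad + r, q * stride - pad + s
                            if 0 <= h < H and 0 <= wi < W:
                                for c in range(C):
                                    acc += x[n][h][wi][c] * w[k][r][s][c]
                    y[n][p][q][k] = acc
    return y

def conv2d_implicit_gemm(x, w, dims):
    """Same conv as (M x K) @ (K x N) with M = N*P*Q, K = C*R*S, N = K_out."""
    N, H, W, C, K_out, R, S, stride, pad = dims
    P, Q = out_size(H, R, stride, pad), out_size(W, S, stride, pad)
    gemm_m, gemm_k = N * P * Q, C * R * S

    def a_elem(m, kk):
        # Implicit im2col: row m -> (n, p, q); column kk -> (r, s, c), c fastest (assumed).
        n, pq = divmod(m, P * Q)
        p, q = divmod(pq, Q)
        rs, c = divmod(kk, C)
        r, s = divmod(rs, S)
        h, wi = p * stride - pad + r, q * stride - pad + s
        return x[n][h][wi][c] if 0 <= h < H and 0 <= wi < W else 0  # padding reads as zero

    def b_elem(kk, k):
        # Filter viewed as K x N: column k is output channel k.
        rs, c = divmod(kk, C)
        r, s = divmod(rs, S)
        return w[k][r][s][c]

    return [[sum(a_elem(m, kk) * b_elem(kk, k) for kk in range(gemm_k))
             for k in range(K_out)] for m in range(gemm_m)]

random.seed(0)
dims = (2, 5, 4, 3, 4, 3, 3, 1, 1)  # N, H, W, C, K_out, R, S, stride, pad
N, H, W, C, K_out, R, S, stride, pad = dims
x = [[[[random.randint(-3, 3) for _ in range(C)] for _ in range(W)] for _ in range(H)] for _ in range(N)]
w = [[[[random.randint(-3, 3) for _ in range(C)] for _ in range(S)] for _ in range(R)] for _ in range(K_out)]
y = conv2d_direct(x, w, dims)
c_mat = conv2d_implicit_gemm(x, w, dims)
P, Q = out_size(H, R, stride, pad), out_size(W, S, stride, pad)
# Row m of C is output pixel (n, p, q) with m = (n*P + p)*Q + q.
ok = all(c_mat[(n * P + p) * Q + q][k] == y[n][p][q][k]
         for n in range(N) for p in range(P) for q in range(Q) for k in range(K_out))
print("implicit GEMM matches direct conv:", ok)
```

Integer inputs keep both paths exact, so the comparison can use equality rather than a tolerance.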

Parameters

  • act_type (DType): Activation data type.
  • filter_type (DType): Filter data type.
  • out_type (DType): Output data type.
  • config (Conv2dConfig): Kernel configuration.

Fields

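  • input_tiles (Conv2dSmem[act_type, filter_type, out_type, config=config].InputTiles): Activation (A) and filter (B) tiles for each pipeline stage.
  • output_tiles (Conv2dSmem[act_type, filter_type, out_type, config=config].OutputTiles): Output (C) tiles staged for the epilogue.
  • source_tiles (Conv2dSmem[act_type, filter_type, out_type, config=config].SourceTiles): Source C tiles loaded for residual operations.
  • pipelines (Conv2dSmem[act_type, filter_type, out_type, config=config].Pipelines): Pipeline storage for the input-tile, accumulator, and CLC pipelines.
  • epi_load_pipeline (Conv2dSmem[act_type, filter_type, out_type, config=config].EpiLoadPipeline): Barriers for the epilogue load pipeline that streams source C tiles.
  • load_order_barrier (Conv2dSmem[act_type, filter_type, out_type, config=config].LoadOrderBarrier): Barrier that orders the MainLoad and EpilogueLoad warps.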

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

act_smem_elements

comptime act_smem_elements = Layout[Coord[ComptimeInt[8], ComptimeInt[(Conv2dSmem[act_type, filter_type, out_type, config=config].BM // 8)]], Coord[ComptimeInt[(config.a_swizzle.bytes() // size_of[act_type]())], ComptimeInt[(Conv2dSmem[act_type, filter_type, out_type, config=config].BK // (config.a_swizzle.bytes() // size_of[act_type]()))]], Coord[ComptimeInt[(config.a_swizzle.bytes() // size_of[act_type]())], ComptimeInt[(8 * (config.a_swizzle.bytes() // size_of[act_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (Conv2dSmem[act_type, filter_type, out_type, config=config].BK == (config.a_swizzle.bytes() // size_of[act_type]())) else (Conv2dSmem[act_type, filter_type, out_type, config=config].BM * (config.a_swizzle.bytes() // size_of[act_type]()))]]].static_product

ActTileArray

comptime ActTileArray = Conv2dSmem[act_type, filter_type, out_type, config=config].InputTiles.ATileArray

BK

comptime BK = config.block_tile_shape[2]

BM

comptime BM = config.block_tile_shape[0]

BN

comptime BN = config.block_tile_shape[1]

EpiLoadBarriers

comptime EpiLoadBarriers = Conv2dSmem[act_type, filter_type, out_type, config=config].EpiLoadPipeline.BarrierArray

EpiLoadPipeline

comptime EpiLoadPipeline = EpiLoadPipelineStorage[Conv2dSmem[act_type, filter_type, out_type, config=config].num_epi_load_stages]

filter_smem_elements

comptime filter_smem_elements = Layout[Coord[ComptimeInt[8], ComptimeInt[(Conv2dSmem[act_type, filter_type, out_type, config=config].BN // 8)]], Coord[ComptimeInt[(config.b_swizzle.bytes() // size_of[filter_type]())], ComptimeInt[(Conv2dSmem[act_type, filter_type, out_type, config=config].BK // (config.b_swizzle.bytes() // size_of[filter_type]()))]], Coord[ComptimeInt[(config.b_swizzle.bytes() // size_of[filter_type]())], ComptimeInt[(8 * (config.b_swizzle.bytes() // size_of[filter_type]()))]], Coord[ComptimeInt[1], ComptimeInt[0 if (Conv2dSmem[act_type, filter_type, out_type, config=config].BK == (config.b_swizzle.bytes() // size_of[filter_type]())) else (Conv2dSmem[act_type, filter_type, out_type, config=config].BN * (config.b_swizzle.bytes() // size_of[filter_type]()))]]].static_product
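The act_smem_elements and filter_smem_elements expressions are easier to read once evaluated. The Python sketch below rebuilds the shape and stride of those K-major swizzled tile layouts, assuming the first two coordinate groups form the shape ((8, BM/8), (swizzle_elems, BK/swizzle_elems)), the last two form the stride, static_product multiplies the shape entries, and swizzle.bytes() returns the swizzle width in bytes. Under those assumptions both counts reduce to BM·BK and BN·BK. The tile sizes, bf16 element width, and 128-byte swizzle are illustrative values, not Conv2dConfig defaults.

```python
def swizzle_elems(swizzle_bytes, elem_bytes):
    # Elements per swizzle-atom row: config.*_swizzle.bytes() // size_of[dtype]().
    return swizzle_bytes // elem_bytes

def kmajor_tile(rows, bk, swizzle_bytes, elem_bytes):
    """Return (shape, stride) of a K-major swizzled tile layout as nested tuples."""
    se = swizzle_elems(swizzle_bytes, elem_bytes)
    assert rows % 8 == 0 and bk % se == 0, "tile must divide into swizzle atoms"
    shape = ((8, rows // 8), (se, bk // se))
    # The second K-mode stride is 0 when BK fits in a single swizzle atom.
    stride = ((se, 8 * se), (1, 0 if bk == se else rows * se))
    return shape, stride

def static_product(shape):
    # Product of all nested shape entries = number of elements in the tile.
    total = 1
    for mode in shape:
        for extent in mode:
            total *= extent
    return total

BM, BN, BK = 128, 256, 64  # illustrative block_tile_shape
act_shape, act_stride = kmajor_tile(BM, BK, 128, 2)  # bf16 activations, 128B swizzle
flt_shape, flt_stride = kmajor_tile(BN, BK, 128, 2)  # bf16 filter, 128B swizzle
act_smem_elements = static_product(act_shape)
filter_smem_elements = static_product(flt_shape)
print(act_shape, act_smem_elements)     # ((8, 16), (64, 1)) 8192
print(flt_shape, filter_smem_elements)  # ((8, 32), (64, 1)) 16384
```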

FilterTileArray

comptime FilterTileArray = Conv2dSmem[act_type, filter_type, out_type, config=config].InputTiles.BTileArray

InputTiles

comptime InputTiles = StandardTileStorage[act_type, filter_type, IndexList(VariadicList(Conv2dSmem[act_type, filter_type, out_type, config=config].BM, Conv2dSmem[act_type, filter_type, out_type, config=config].BK), Tuple()), IndexList(VariadicList(Conv2dSmem[act_type, filter_type, out_type, config=config].BN, Conv2dSmem[act_type, filter_type, out_type, config=config].BK), Tuple()), Conv2dSmem[act_type, filter_type, out_type, config=config].num_pipeline_stages]
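InputTiles keeps one BM x BK activation tile and one BN x BK filter tile for each of the num_pipeline_stages stages, so its footprint grows linearly with the stage count. The Python estimate below assumes the tiles are packed without padding and ignores barrier storage; the tile sizes, 2-byte element widths, four stages, and the 200 KiB budget are hypothetical numbers chosen for illustration.

```python
def input_tile_bytes(bm, bn, bk, act_bytes, filter_bytes, num_pipeline_stages):
    """Approximate InputTiles footprint: one A (BM x BK) and one B (BN x BK) tile per stage."""
    per_stage = bm * bk * act_bytes + bn * bk * filter_bytes
    return per_stage, per_stage * num_pipeline_stages

def max_stages(smem_budget_bytes, per_stage_bytes):
    # How many whole stages fit in a given byte budget (the budget is an assumption).
    return smem_budget_bytes // per_stage_bytes

per_stage, total = input_tile_bytes(128, 256, 64, 2, 2, 4)
print(per_stage, total)                        # 49152 196608
print(max_stages(200 * 1024, per_stage))       # 4
```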

LoadOrderBarrier

comptime LoadOrderBarrier = LoadOrderBarrierStorage

LoadOrderBarriers

comptime LoadOrderBarriers = Conv2dSmem[act_type, filter_type, out_type, config=config].LoadOrderBarrier.BarrierArray

num_accum_pipeline_stages

comptime num_accum_pipeline_stages = config.num_accum_pipeline_stages

num_clc_pipeline_stages

comptime num_clc_pipeline_stages = config.num_clc_pipeline_stages

num_epi_load_stages

comptime num_epi_load_stages = 2

num_group_pipeline_stages

comptime num_group_pipeline_stages = (Conv2dSmem[act_type, filter_type, out_type, config=config].num_pipeline_stages // config)

num_output_stages

comptime num_output_stages = config.num_output_stages

num_pipeline_stages

comptime num_pipeline_stages = config.num_pipeline_stages

out_smem_layout

comptime out_smem_layout = Layout.row_major(VariadicList(Conv2dSmem[act_type, filter_type, out_type, config=config].OutputM, Conv2dSmem[act_type, filter_type, out_type, config=config].OutputN))

OutputM

comptime OutputM = config.output_tile_shape[0]

OutputN

comptime OutputN = config.output_tile_shape[1]

OutputTiles

comptime OutputTiles = OutputTileStorage[out_type, Conv2dSmem[act_type, filter_type, out_type, config=config].OutputM, Conv2dSmem[act_type, filter_type, out_type, config=config].OutputN, Conv2dSmem[act_type, filter_type, out_type, config=config].num_output_stages]
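OutputTiles is parameterized by OutputM, OutputN, and num_output_stages, and out_smem_layout is row-major over (OutputM, OutputN). Assuming the storage holds one such tile per output stage with no padding, element addressing and footprint work out as in this Python sketch; the 128 x 32 tile, 2-byte out_type, and two stages are illustrative.

```python
def row_major_offset(m, n, output_n):
    # Layout.row_major(OutputM, OutputN): element (m, n) lives at m * OutputN + n.
    return m * output_n + n

def output_tile_bytes(output_m, output_n, out_bytes, num_output_stages):
    # One OutputM x OutputN tile per output stage (padding not counted).
    return output_m * output_n * out_bytes * num_output_stages

OutputM, OutputN = 128, 32  # illustrative output_tile_shape
print(row_major_offset(3, 5, OutputN))            # 101
print(output_tile_bytes(OutputM, OutputN, 2, 2))  # 16384
```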

OutTileArray

comptime OutTileArray = Conv2dSmem[act_type, filter_type, out_type, config=config].OutputTiles.CTileArray

Pipelines

comptime Pipelines = SmemPipelineBundle[Conv2dSmem[act_type, filter_type, out_type, config=config].num_group_pipeline_stages, Conv2dSmem[act_type, filter_type, out_type, config=config].num_accum_pipeline_stages, Conv2dSmem[act_type, filter_type, out_type, config=config].num_clc_pipeline_stages, StandardTilePayload[act_type, filter_type, IndexList(VariadicList(Conv2dSmem[act_type, filter_type, out_type, config=config].BM, Conv2dSmem[act_type, filter_type, out_type, config=config].BK), Tuple()), IndexList(VariadicList(Conv2dSmem[act_type, filter_type, out_type, config=config].BN, Conv2dSmem[act_type, filter_type, out_type, config=config].BK), Tuple()), Conv2dSmem[act_type, filter_type, out_type, config=config].num_pipeline_stages]]
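SmemPipelineBundle receives separate stage counts for the input-tile, accumulator, and CLC pipelines. Multi-stage shared-memory pipelines of this kind are commonly driven by a ring index plus a one-bit phase that flips each time the ring wraps, which lets a barrier wait distinguish the current pass from the previous one. The Python sketch below shows only that generic bookkeeping; PipelineState is a hypothetical name and is not the SmemPipelineBundle API.

```python
class PipelineState:
    """Hypothetical ring-buffer cursor: stage index plus a phase bit that flips on wrap."""

    def __init__(self, num_stages):
        self.num_stages = num_stages
        self.index = 0
        self.phase = 0

    def advance(self):
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1  # barrier waits compare against this parity

state = PipelineState(num_stages=3)
trace = []
for _ in range(7):
    trace.append((state.index, state.phase))
    state.advance()
print(trace)  # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```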

SourceTiles

comptime SourceTiles = SourceTileStorage[out_type, IndexList(VariadicList(Conv2dSmem[act_type, filter_type, out_type, config=config].OutputM, Conv2dSmem[act_type, filter_type, out_type, config=config].OutputN), Tuple()), Conv2dSmem[act_type, filter_type, out_type, config=config].num_epi_load_stages]

SrcTileArray

comptime SrcTileArray = Conv2dSmem[act_type, filter_type, out_type, config=config].SourceTiles.SrcTileArray

Methods

act_tiles

act_tiles(ref[AddressSpace._value] self) -> Conv2dSmem[act_type, filter_type, out_type, config=config].ActTileArray

Get activation tiles (im2col'd).

Returns:

Conv2dSmem[act_type, filter_type, out_type, config=config].ActTileArray

filter_tiles

filter_tiles(ref[AddressSpace._value] self) -> Conv2dSmem[act_type, filter_type, out_type, config=config].FilterTileArray

Get filter tiles.

Returns:

Conv2dSmem[act_type, filter_type, out_type, config=config].FilterTileArray

out_tiles

out_tiles(ref[AddressSpace._value] self) -> Conv2dSmem[act_type, filter_type, out_type, config=config].OutTileArray

Get output tiles.

Returns:

Conv2dSmem[act_type, filter_type, out_type, config=config].OutTileArray

src_tiles

src_tiles(ref[AddressSpace._value] self) -> Conv2dSmem[act_type, filter_type, out_type, config=config].SrcTileArray

Get source C tiles (for residual operations).

Returns:

Conv2dSmem[act_type, filter_type, out_type, config=config].SrcTileArray

epi_load_barriers

epi_load_barriers(ref[AddressSpace._value] self) -> Conv2dSmem[act_type, filter_type, out_type, config=config].EpiLoadBarriers

Get epilogue load pipeline barriers.

Synchronizes the EpilogueLoad warp (producer) with the Epilogue warps (consumers) when loading the source C tensor.
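The producer/consumer handshake described above follows the usual full/empty barrier pattern over a small ring of buffers. As a CPU analogy only (Python threads and semaphores rather than GPU warps and shared-memory barriers), the sketch below streams source C tiles through num_epi_load_stages = 2 slots; it uses a single consumer thread for simplicity, whereas the kernel has several Epilogue warps, and the tile count and names are illustrative.

```python
import threading

NUM_EPI_LOAD_STAGES = 2  # matches num_epi_load_stages
NUM_TILES = 5            # illustrative number of source C tiles to stream

buffers = [None] * NUM_EPI_LOAD_STAGES
full = [threading.Semaphore(0) for _ in range(NUM_EPI_LOAD_STAGES)]   # "tile ready"
empty = [threading.Semaphore(1) for _ in range(NUM_EPI_LOAD_STAGES)]  # "slot free"
consumed = []

def epilogue_load_producer():
    for t in range(NUM_TILES):
        s = t % NUM_EPI_LOAD_STAGES
        empty[s].acquire()            # wait until the consumer released slot s
        buffers[s] = f"src_tile_{t}"  # stands in for loading a source C tile
        full[s].release()             # signal: slot s now holds tile t

def epilogue_consumer():
    for t in range(NUM_TILES):
        s = t % NUM_EPI_LOAD_STAGES
        full[s].acquire()             # wait for the producer's load
        consumed.append(buffers[s])   # e.g., apply the residual to the output
        empty[s].release()            # hand slot s back to the producer

threads = [threading.Thread(target=epilogue_load_producer),
           threading.Thread(target=epilogue_consumer)]
for th in threads:
    th.start()
for th in threads:
    th.join()
print(consumed)  # ['src_tile_0', 'src_tile_1', 'src_tile_2', 'src_tile_3', 'src_tile_4']
```

With two slots, the producer can load tile t + 1 while the consumer is still working on tile t, but it can never overwrite a slot that has not been released.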

Returns:

Conv2dSmem[act_type, filter_type, out_type, config=config].EpiLoadBarriers

get_load_order_barrier

get_load_order_barrier(ref[AddressSpace._value] self) -> Conv2dSmem[act_type, filter_type, out_type, config=config].LoadOrderBarriers

Get load order barrier.

Coordinates the MainLoad warp with the EpilogueLoad warp so that epilogue loads do not start before the mainloop prologue completes.
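The ordering guarantee can be illustrated with a one-shot event. In this CPU analogy (Python threads standing in for the MainLoad and EpilogueLoad warps, threading.Event standing in for the load order barrier), the epilogue load cannot record its first source C load until every prologue stage has been issued; the stage count and messages are hypothetical.

```python
import threading

log = []
log_lock = threading.Lock()
prologue_done = threading.Event()  # stands in for the load order barrier

def record(msg):
    with log_lock:
        log.append(msg)

def main_load(num_prologue_stages=3):
    for s in range(num_prologue_stages):
        record(f"mainload: prologue stage {s}")  # issue A/B tile loads for stage s
    prologue_done.set()                          # arrive on the barrier
    record("mainload: steady-state loop")

def epilogue_load():
    prologue_done.wait()                         # block until the prologue is issued
    record("epilogue_load: first source C load")

# Start the epilogue thread first to show that it still waits for the prologue.
threads = [threading.Thread(target=epilogue_load), threading.Thread(target=main_load)]
for th in threads:
    th.start()
for th in threads:
    th.join()
print(log)
```

Only the prologue-before-epilogue ordering is guaranteed; the relative order of the epilogue load and the main loop's steady state is left to the scheduler.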

Returns:

Conv2dSmem[act_type, filter_type, out_type, config=config].LoadOrderBarriers
