
Mojo struct

TilePipeline

@register_passable(trivial) struct TilePipeline[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]

Staged tile storage with producer-consumer synchronization for SM100 (NVIDIA Blackwell) GPUs.

Manages a fixed set of pipeline stages (not a FIFO queue) where:

  • Producer (TMA Load) fills tiles into its current stage
  • Consumer (MMA) reads tiles from its current stage
  • Barriers coordinate access between producer and consumer
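The staging scheme above is a bounded producer-consumer ring. The following host-side Python sketch illustrates the idea under simplifying assumptions: threads stand in for the TMA Load and MMA warps, `threading.Semaphore` objects stand in for the shared-memory barriers, and `NUM_STAGES` and `NUM_ITEMS` are hypothetical values. It does not use the Mojo API.

```python
import threading

NUM_STAGES = 4   # hypothetical number of pipeline stages
NUM_ITEMS = 10   # hypothetical number of K iterations to stream

stages = [None] * NUM_STAGES  # stands in for the shared-memory tile slots
# "full" signals producer -> consumer; "empty" signals consumer -> producer.
full = [threading.Semaphore(0) for _ in range(NUM_STAGES)]
empty = [threading.Semaphore(1) for _ in range(NUM_STAGES)]  # every slot starts free
consumed = []


def producer():
    for i in range(NUM_ITEMS):
        s = i % NUM_STAGES
        empty[s].acquire()       # wait until the consumer has released slot s
        stages[s] = f"tile{i}"   # plays the role of the TMA load
        full[s].release()        # mark slot s as ready


def consumer():
    for i in range(NUM_ITEMS):
        s = i % NUM_STAGES
        full[s].acquire()            # wait until the producer has filled slot s
        consumed.append(stages[s])   # plays the role of the MMA reading the tile
        empty[s].release()           # hand slot s back to the producer


threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(consumed)  # tiles arrive in production order: tile0 through tile9
```

Because the consumer releases a slot only after reading it, the producer can run at most `NUM_STAGES` tiles ahead, and no slot is overwritten before it has been consumed.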

Template Parameters:

  • a_type: Data type for A matrix tiles.
  • b_type: Data type for B matrix tiles.
  • a_tile_layout: Memory layout for A tiles.
  • b_tile_layout: Memory layout for B tiles.
  • num_pipeline_stages: Total number of tile stages (num_group_stages * k_group_size).
  • num_group_stages: Number of synchronization stages.
  • k_group_size: Number of tiles per synchronization stage.
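num_pipeline_stages counts tile slots, while num_group_stages counts barrier-protected groups, so a single barrier wait covers k_group_size tiles. The Python sketch below shows that arithmetic. It assumes, as an illustration only, that the tiles of one synchronization stage occupy consecutive slots; `flat_tile_index` is a hypothetical helper, not part of the documented API.

```python
def flat_tile_index(group_stage: int, k: int, k_group_size: int) -> int:
    """Map (synchronization stage, tile within that stage) to a flat tile slot.

    Assumes the k_group_size tiles of one synchronization stage are stored
    next to each other; this layout is an illustration, not confirmed.
    """
    if not 0 <= k < k_group_size:
        raise ValueError("k must be in [0, k_group_size)")
    return group_stage * k_group_size + k


num_group_stages = 3  # hypothetical values
k_group_size = 2
num_pipeline_stages = num_group_stages * k_group_size

slots = [
    flat_tile_index(g, k, k_group_size)
    for g in range(num_group_stages)
    for k in range(k_group_size)
]
print(num_pipeline_stages, slots)  # 6 [0, 1, 2, 3, 4, 5]
```

With these values, three barrier stages guard six tile slots, and every slot is reached exactly once.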

Fields

  • pipeline (TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].Pipeline): Producer-consumer synchronization state for the num_group_stages stages.
  • a_tiles (TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray): Shared-memory array of num_pipeline_stages A tiles.
  • b_tiles (TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray): Shared-memory array of num_pipeline_stages B tiles.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = True

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

ATile

comptime ATile = TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray.Tile

ATileArray

comptime ATileArray = SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128]

BarrierArray

comptime BarrierArray = SMemArrayType[SharedMemBarrier, (num_group_stages * 2)]

BTile

comptime BTile = TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray.Tile

BTileArray

comptime BTileArray = SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]
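The trailing `128` in ATileArray and BTileArray is assumed here to be a per-tile byte alignment. This page does not say so, so treat that as an assumption. Under that reading, each tile begins on a 128-byte boundary, and the shared-memory footprint of one array is the padded tile size times num_pipeline_stages. The Python sketch below does that arithmetic; the tile shape and element size in the example are hypothetical.

```python
def align_up(nbytes: int, alignment: int = 128) -> int:
    """Round nbytes up to a multiple of alignment (alignment must be a power of two)."""
    assert alignment > 0 and alignment & (alignment - 1) == 0
    return (nbytes + alignment - 1) & ~(alignment - 1)


def tile_array_bytes(rows: int, cols: int, elem_bytes: int,
                     num_tiles: int, alignment: int = 128):
    """Per-tile stride, tile offsets, and total bytes when each tile starts aligned."""
    stride = align_up(rows * cols * elem_bytes, alignment)
    offsets = [i * stride for i in range(num_tiles)]
    return stride, offsets, stride * num_tiles


# Hypothetical A tile: 128 x 64 elements of a 2-byte type, 4 pipeline stages.
stride, offsets, total = tile_array_bytes(128, 64, 2, 4)
print(stride, offsets, total)  # 16384 [0, 16384, 32768, 49152] 65536
```

Tile shapes whose byte size is already a multiple of 128 need no padding. A 200-byte tile, by contrast, would be padded to a 256-byte stride.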

Pipeline

comptime Pipeline = ProducerConsumerPipeline[num_group_stages]

Methods

__init__

__init__(barriers: SMemArrayType[SharedMemBarrier, (num_group_stages * 2)], a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]) -> Self

Initialize from a typed barrier array and the A and B tile arrays.

init_barriers

static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)

Initialize the pipeline barriers. Called once, by the single thread selected with elect_one.
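producer_arv_count and consumer_arv_count presumably set how many arrivals each barrier expects before a phase completes. The page does not define them, so this is an assumption. Initializing from one elected thread also keeps several threads from resetting the same counters concurrently. The toy Python model below illustrates arrival-count semantics. `ArrivalBarrier` is hypothetical, is not the SharedMemBarrier API, and uses made-up counts.

```python
class ArrivalBarrier:
    """Toy model of an arrival-count barrier (not the SharedMemBarrier API).

    A phase completes once `expected` arrivals have been recorded; the barrier
    then resets its pending count for the next phase.
    """

    def __init__(self, expected: int):
        self.expected = expected
        self.pending = expected
        self.completed_phases = 0

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:          # last expected arrival ends the phase
            self.completed_phases += 1
            self.pending = self.expected


# Hypothetical counts: one arrival per phase on one side, two on the other.
producer_side = ArrivalBarrier(expected=1)
consumer_side = ArrivalBarrier(expected=2)

producer_side.arrive()
consumer_side.arrive()
print(producer_side.completed_phases, consumer_side.completed_phases)  # 1 0
consumer_side.arrive()
print(consumer_side.completed_phases)  # 1
```

A count that is too low would let a phase complete before every participant has arrived. A count that is too high would leave waiters blocked, which is why the counts must match the number of arriving threads or warps.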

producer

producer[origin: MutOrigin](ref [origin] self) -> InputProducer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]

Get the producer view for the TMA Load warp.

Returns:

InputProducer

consumer

consumer[origin: MutOrigin](ref [origin] self) -> InputConsumer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]

Get the consumer view for the MMA warp.

Returns:

InputConsumer

producer_stage

producer_stage(self) -> UInt32

Return the index of the producer's current stage.

Returns:

UInt32

consumer_stage

consumer_stage(self) -> UInt32

Return the index of the consumer's current stage.

Returns:

UInt32
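Beyond returning UInt32, this page does not document producer_stage and consumer_stage. The sketch below rests on an assumed design that is common for this kind of pipeline: each side keeps its own stage index in [0, num_group_stages) plus a one-bit phase that flips each time the index wraps, so that a wait on a reused barrier can distinguish the current lap from the previous one. `StageCursor` is hypothetical, and the starting phase may differ in the real implementation.

```python
class StageCursor:
    """Hypothetical ring cursor: a stage index plus a one-bit phase."""

    def __init__(self, num_stages: int, phase: int = 0):
        self.num_stages = num_stages
        self.stage = 0
        self.phase = phase

    def step(self) -> None:
        self.stage += 1
        if self.stage == self.num_stages:  # wrapped past the last stage
            self.stage = 0
            self.phase ^= 1                # next lap waits on the opposite parity


cursor = StageCursor(num_stages=3)
trace = []
for _ in range(7):
    trace.append((cursor.stage, cursor.phase))
    cursor.step()
print(trace)  # [(0, 0), (1, 0), (2, 0), (0, 1), (1, 1), (2, 1), (0, 0)]
```

Because the producer and consumer each advance their own cursor, the two stage indices can differ, with the producer running ahead by up to the number of stages.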

producer_mbar

producer_mbar(self, stage: UInt32) -> MbarPtr

Return a pointer to the producer barrier for the given stage.

Returns:

MbarPtr

consumer_mbar

consumer_mbar(self, stage: UInt32) -> MbarPtr

Return a pointer to the consumer barrier for the given stage.

Returns:

MbarPtr
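BarrierArray holds num_group_stages * 2 barriers, which suggests one producer-side and one consumer-side barrier per synchronization stage. The Python sketch below assumes the first half of the array serves producer_mbar and the second half serves consumer_mbar. This page does not state the actual ordering inside ProducerConsumerPipeline, and `barrier_slots` is a hypothetical helper.

```python
def barrier_slots(num_group_stages: int, stage: int) -> tuple[int, int]:
    """Hypothetical mapping of a stage to its two barriers in a 2*N array."""
    if not 0 <= stage < num_group_stages:
        raise ValueError("stage must be in [0, num_group_stages)")
    producer_slot = stage                     # assumed target of producer_mbar(stage)
    consumer_slot = num_group_stages + stage  # assumed target of consumer_mbar(stage)
    return producer_slot, consumer_slot


num_group_stages = 4
print(barrier_slots(num_group_stages, 2))  # (2, 6)

used = sorted(i for s in range(num_group_stages) for i in barrier_slots(num_group_stages, s))
print(used == list(range(2 * num_group_stages)))  # True
```

Under this layout, every barrier in the array belongs to exactly one stage and one side, with none shared or left unused.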
