Mojo struct
TilePipeline
@register_passable("trivial")
struct TilePipeline[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]
Staged tile storage with producer-consumer synchronization for SM100.
Manages a fixed set of pipeline stages (not a FIFO queue) where:
- Producer (TMA Load) fills tiles into its current stage
- Consumer (MMA) reads tiles from its current stage
- Barriers coordinate access between producer and consumer
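The producer/consumer handshake described above can be modeled on the CPU. This is an illustrative Python sketch only, not Mojo and not the library's API. Semaphores stand in for the shared-memory barriers, threads stand in for the TMA load and MMA warps, and stages are reused round-robin, which is an assumption because this page does not spell out the stage-advance rule. `run_pipeline` is a hypothetical name.

```python
import threading


def run_pipeline(num_stages: int, num_items: int) -> list[int]:
    """CPU-side model of a staged producer/consumer handshake (hypothetical)."""
    slots = [None] * num_stages
    # One "full" signal per stage (producer -> consumer): nothing loaded yet.
    full = [threading.Semaphore(0) for _ in range(num_stages)]
    # One "empty" signal per stage (consumer -> producer): every stage starts free.
    empty = [threading.Semaphore(1) for _ in range(num_stages)]
    consumed: list[int] = []

    def producer() -> None:  # stands in for the TMA load warp
        for i in range(num_items):
            s = i % num_stages   # stages are reused round-robin
            empty[s].acquire()   # wait until the consumer has released stage s
            slots[s] = i         # "load" a tile into stage s
            full[s].release()    # signal that stage s now holds data

    def consumer() -> None:  # stands in for the MMA warp
        for i in range(num_items):
            s = i % num_stages
            full[s].acquire()          # wait until stage s has been filled
            consumed.append(slots[s])  # "compute" on the tile
            empty[s].release()         # hand stage s back to the producer

    threads = [threading.Thread(target=producer), threading.Thread(target=consumer)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return consumed
```

Because a stage is only overwritten after the consumer releases it, every item is consumed exactly once and in order, for example `run_pipeline(2, 10)` yields `0` through `9` even though only two slots exist.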
Template Parameters:
- a_type: Data type for A matrix tiles.
- b_type: Data type for B matrix tiles.
- a_tile_layout: Memory layout for A tiles.
- b_tile_layout: Memory layout for B tiles.
- num_pipeline_stages: Total number of tile stages (num_group_stages * k_group_size).
- num_group_stages: Number of synchronization stages.
- k_group_size: Number of tiles per synchronization stage.
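The relation between the three stage parameters can be checked with a small Python sketch. It assumes, although this page does not say so, that the tiles belonging to one synchronization stage sit contiguously in the tile arrays. `tile_indices` and `total_tile_stages` are hypothetical helpers, not part of the API.

```python
def total_tile_stages(num_group_stages: int, k_group_size: int) -> int:
    # Documented relation: num_pipeline_stages = num_group_stages * k_group_size.
    return num_group_stages * k_group_size


def tile_indices(group_stage: int, num_group_stages: int, k_group_size: int) -> list[int]:
    # Hypothetical layout: group stage g owns tiles [g * k_group_size, (g + 1) * k_group_size).
    if not 0 <= group_stage < num_group_stages:
        raise ValueError("group_stage out of range")
    start = group_stage * k_group_size
    return list(range(start, start + k_group_size))
```

With `num_group_stages = 4` and `k_group_size = 2`, there are 8 tile slots per operand, and one barrier handshake covers two tiles at a time (group stage 1 covers tiles 2 and 3).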
Fields
- pipeline (TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].Pipeline)
- a_tiles (TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray)
- b_tiles (TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray)
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
ATile
comptime ATile = TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray.Tile
ATileArray
comptime ATileArray = SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128]
BarrierArray
comptime BarrierArray = SMemArrayType[SharedMemBarrier, (num_group_stages * 2)]
BTile
comptime BTile = TilePipeline[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray.Tile
BTileArray
comptime BTileArray = SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]
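Because both tile arrays hold num_pipeline_stages tiles and the barrier array holds num_group_stages * 2 barriers, a rough shared-memory budget can be estimated ahead of time. The Python sketch below is hypothetical. It takes tile sizes in elements as plain numbers and ignores any padding, including whatever the trailing 128 parameter of SMemTileArrayType implies, which this page does not explain. The only hardware fact it relies on is that a PTX mbarrier object occupies 8 bytes.

```python
MBARRIER_BYTES = 8  # a PTX mbarrier object is a 64-bit shared-memory object


def smem_footprint_bytes(
    a_tile_elems: int,
    a_elem_bytes: int,
    b_tile_elems: int,
    b_elem_bytes: int,
    num_pipeline_stages: int,
    num_group_stages: int,
) -> int:
    """Rough shared-memory estimate for one TilePipeline (no alignment padding)."""
    a_bytes = num_pipeline_stages * a_tile_elems * a_elem_bytes
    b_bytes = num_pipeline_stages * b_tile_elems * b_elem_bytes
    barrier_bytes = 2 * num_group_stages * MBARRIER_BYTES  # BarrierArray size
    return a_bytes + b_bytes + barrier_bytes
```

For example, 128 x 64 bf16 tiles (2 bytes per element) for both A and B, with num_group_stages = 2 and k_group_size = 2 (so num_pipeline_stages = 4), come to 131,104 bytes: 65,536 for each operand plus 32 for the barriers.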
Pipeline
comptime Pipeline = ProducerConsumerPipeline[num_group_stages]
Methods
__init__
__init__(barriers: SMemArrayType[SharedMemBarrier, (num_group_stages * 2)], a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]) -> Self
Initialize from a typed barrier array and the A and B tile arrays.
init_barriers
static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)
Initialize the pipeline barriers. Called once, by the elect_one thread.
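The two arrival counts passed to init_barriers set how many arrivals complete a barrier phase. The Python sketch below is a toy model under stated assumptions: it keeps only the arrival count and a phase bit (real mbarriers can also track transaction bytes), and it assumes the first num_group_stages barriers take producer_arv_count while the second half take consumer_arv_count. That split and the names `ArrivalBarrier` and `init_barriers` here are hypothetical.

```python
class ArrivalBarrier:
    """Toy arrival-count barrier with a phase bit (no transaction-byte tracking)."""

    def __init__(self, expected_arrivals: int):
        if expected_arrivals <= 0:
            raise ValueError("expected_arrivals must be positive")
        self.expected = expected_arrivals
        self.pending = expected_arrivals
        self.phase = 0

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:  # the last expected arrival completes the phase
            self.phase ^= 1  # waiters on the old parity may now proceed
            self.pending = self.expected  # re-arm for the next phase


def init_barriers(num_group_stages: int, producer_arv_count: int, consumer_arv_count: int):
    # Hypothetical split: producer-signaled barriers first, consumer-signaled second.
    producer_side = [ArrivalBarrier(producer_arv_count) for _ in range(num_group_stages)]
    consumer_side = [ArrivalBarrier(consumer_arv_count) for _ in range(num_group_stages)]
    return producer_side + consumer_side
```

Initializing once matters because re-initializing a barrier while the other side is waiting on it would reset its pending count and phase, so only a single elected thread performs this step.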
producer
producer[origin: MutOrigin](ref [origin] self) -> InputProducer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]
Get the producer view for the TMA Load warp.
Returns:
InputProducer
consumer
consumer[origin: MutOrigin](ref [origin] self) -> InputConsumer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]
Get the consumer view for the MMA warp.
Returns:
InputConsumer
producer_stage
consumer_stage
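producer_stage and consumer_stage are listed here without signatures or descriptions. Mbarrier-based pipelines commonly track, per side, a stage index plus a phase bit that flips each time the index wraps, so a waiter can tell the current lap through the ring from the previous one. The Python sketch below illustrates that general technique only; `StageCursor` is a hypothetical name, and the actual ProducerConsumerPipeline state may be organized differently.

```python
from dataclasses import dataclass


@dataclass
class StageCursor:
    """Hypothetical per-side cursor: stage index plus phase parity."""

    num_stages: int
    index: int = 0
    phase: int = 0

    def step(self) -> None:
        self.index += 1
        if self.index == self.num_stages:  # wrapped around the ring of stages
            self.index = 0
            self.phase ^= 1  # the next lap waits on the opposite parity
```

With three stages, four steps leave the cursor at stage 1 on phase 1, and two more steps bring it back to stage 0 on phase 0.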
producer_mbar
producer_mbar(self, stage: UInt32) -> MbarPtr
Returns:
MbarPtr
consumer_mbar
consumer_mbar(self, stage: UInt32) -> MbarPtr
Returns:
MbarPtr
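producer_mbar and consumer_mbar each take a stage and return a barrier pointer, and BarrierArray holds num_group_stages * 2 barriers, which suggests two barriers per stage. The Python sketch below shows one plausible addressing scheme under an explicit assumption: producer-side barriers occupy the first half of the array and consumer-side barriers the second half. The page does not confirm this layout, and `mbar_slots` is a hypothetical helper.

```python
def mbar_slots(stage: int, num_group_stages: int) -> tuple[int, int]:
    """Return (producer_slot, consumer_slot) in a 2 * num_group_stages barrier array.

    Hypothetical layout: producer-side barriers first, consumer-side second.
    """
    if not 0 <= stage < num_group_stages:
        raise ValueError(f"stage {stage} out of range [0, {num_group_stages})")
    return stage, num_group_stages + stage
```

Under this layout every stage maps to two distinct slots, and together the stages cover the whole barrier array exactly once.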