Mojo struct
RingBuffer
struct RingBuffer[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy]
Ring buffer for coordinating producer-consumer warps in matrix multiplication.
Parameters
- dtype (DType): Data type of elements.
- layout (Layout): Memory layout for shared memory tiles.
- pipeline_stages (Int): Number of stages for software pipelining.
- block_rows (Int): Number of rows in block-level tiles.
- block_cols (Int): Number of columns in block-level tiles.
- warp_rows (Int): Number of rows in warp-level tiles.
- warp_cols (Int): Number of columns in warp-level tiles.
- reads_per_warp_block (Int): How many consumer warps read each tile.
- tile_buffers (Int): Number of separate tile buffers (usually 1).
- sync_strategy_type (SyncStrategy): Synchronization strategy (SingleCounterSync or SplitCounterSync).
Fields
- smem_buffers (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SMemBuffersType)
- sync_strategy (sync_strategy_type)
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
block_warps
comptime block_warps = (block_rows // warp_rows)
SMemBuffersType
comptime SMemBuffersType = StaticTuple[SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols], tile_buffers]
SmemBufferType
comptime SmemBufferType = SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols]
total_tiles
comptime total_tiles = (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].block_warps * pipeline_stages)
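As a quick check of the two derived counts above, the following minimal sketch recomputes block_warps and total_tiles at file scope. The tile sizes (128, 32, 2) are assumptions picked for illustration, not values taken from the library, and the snippet does not import RingBuffer itself.

```mojo
# Hypothetical parameter values, chosen only for illustration.
comptime block_rows = 128
comptime warp_rows = 32
comptime pipeline_stages = 2

# Same formulas as the documented comptime members.
comptime block_warps = block_rows // warp_rows
comptime total_tiles = block_warps * pipeline_stages


def main():
    print("block_warps:", block_warps)
    print("total_tiles:", total_tiles)
```

With these assumed sizes, a block is split into four warp-row tiles per stage, so the ring tracks eight tiles across both pipeline stages.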
WarpTileTupleType
comptime WarpTileTupleType = StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols]() or _tile_is_masked[LayoutTensor._compute_tile_layout[block_rows, block_cols]()[0], warp_rows, warp_cols](), alignment=128], tile_buffers]
WarpTileType
comptime WarpTileType = RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SmemBufferType.WarpTileType
Methods
__init__
__init__(out self)
get_tiles
get_tiles(self, stage: Int, warp_tile_idx: Int) -> RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].WarpTileTupleType
Get tiles from shared memory.
Returns:
RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].WarpTileTupleType
producer
producer[warps_processed_per_producer: Int](mut self) -> ProducerView[origin_of(self), RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_processed_per_producer]
Create a producer view of this ring buffer.
Returns:
ProducerView[origin_of(self), RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_processed_per_producer]
consumer
consumer[warps_computed_per_consumer: Int](mut self) -> ConsumerView[origin_of(self), RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_computed_per_consumer]
Create a consumer view of this ring buffer.
Returns:
ConsumerView[origin_of(self), RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_computed_per_consumer]
get_staged_idx
get_staged_idx(self, tile_idx: Int, stage: Int) -> Int
Get the staged index for a tile and stage.
Returns:
Int
wait_producer_acquire
wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)
Producer waits to acquire a tile.
signal_producer_release
signal_producer_release(mut self, tile_idx: Int, stage: Int)
Producer signals it has released a tile.
wait_consumer_acquire
wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)
Consumer waits to acquire a tile.
signal_consumer_release
signal_consumer_release(mut self, tile_idx: Int, stage: Int)
Consumer signals it has released a tile.
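The four wait/signal methods above suggest an acquire-then-release handshake on each side of the buffer. The sketch below illustrates only that call order, using MockRing, a hypothetical stand-in that copies the documented method names and argument lists. Its bodies just print and count calls; it performs no synchronization, runs on a single thread, and treats phase as an opaque placeholder, so it should not be read as the library's actual behavior.

```mojo
# Hypothetical stand-in for RingBuffer. Only the method names and argument
# lists follow the documented signatures; the bodies just print and count.
struct MockRing:
    var releases: Int

    fn __init__(out self):
        self.releases = 0

    fn wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32):
        print("producer acquire: tile", tile_idx, "stage", stage)

    fn signal_producer_release(mut self, tile_idx: Int, stage: Int):
        self.releases += 1
        print("producer release: tile", tile_idx, "stage", stage)

    fn wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32):
        print("consumer acquire: tile", tile_idx, "stage", stage)

    fn signal_consumer_release(mut self, tile_idx: Int, stage: Int):
        self.releases += 1
        print("consumer release: tile", tile_idx, "stage", stage)


def main():
    var ring = MockRing()
    var phase: Int32 = 0  # placeholder; real phase handling is not modeled
    for stage in range(2):
        # Producer side: acquire the slot, fill it, then release it.
        ring.wait_producer_acquire(0, stage, phase)
        ring.signal_producer_release(0, stage)
        # Consumer side: acquire the filled slot, use it, then release it.
        ring.wait_consumer_acquire(0, stage, phase)
        ring.signal_consumer_release(0, stage)
    print("total releases:", ring.releases)
```

In a real kernel the producer and consumer halves would run in different warps, typically through the views returned by producer and consumer, with the wait calls blocking until the other side signals.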
get_producer_phase_increment
get_consumer_phase_increment