Mojo struct
RingBuffer
struct RingBuffer[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy]
Ring buffer for coordinating producer-consumer warps in matrix multiplication.
Parameters
- dtype (DType): Data type of elements.
- layout (Layout): Memory layout for shared memory tiles.
- pipeline_stages (Int): Number of stages for software pipelining.
- block_rows (Int): Number of rows in block-level tiles.
- block_cols (Int): Number of columns in block-level tiles.
- warp_rows (Int): Number of rows in warp-level tiles.
- warp_cols (Int): Number of columns in warp-level tiles.
- reads_per_warp_block (Int): Number of consumer warps that read each tile.
- tile_buffers (Int): Number of separate tile buffers (usually 1).
- sync_strategy_type (SyncStrategy): Synchronization strategy (SingleCounterSync or SplitCounterSync).
Fields
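- smem_buffers (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SMemBuffersType): The shared-memory tile buffers, a StaticTuple of tile_buffers SMemBuffer instances (see SMemBuffersType).
- sync_strategy (sync_strategy_type): The instance of the synchronization strategy selected by sync_strategy_type.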
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
comptime __del__is_trivial = sync_strategy_type.__del__is_trivial
block_warps
comptime block_warps = (block_rows // warp_rows)
SMemBuffersType
comptime SMemBuffersType = StaticTuple[SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols], tile_buffers]
SmemBufferType
comptime SmemBufferType = SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols]
total_tiles
comptime total_tiles = (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].block_warps * pipeline_stages)
WarpTileTupleType
comptime WarpTileTupleType = StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, block_rows, block_cols]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols](), 128, warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols]() if _tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, block_rows, block_cols]()[0], warp_rows, warp_cols](), alignment=128], tile_buffers]
WarpTileType
comptime WarpTileType = RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SmemBufferType.WarpTileType
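The two integer aliases above can be checked by hand. The sketch below evaluates block_warps and total_tiles for an assumed configuration (block_rows = 128, warp_rows = 32, pipeline_stages = 2, chosen only for illustration, not library defaults). The functions block_warps and total_tiles here are hypothetical plain-function mirrors of the alias formulas; they are not part of the RingBuffer API.

```mojo
# Hypothetical helpers that mirror the comptime aliases on this page:
#   block_warps = block_rows // warp_rows
#   total_tiles = block_warps * pipeline_stages
fn block_warps(block_rows: Int, warp_rows: Int) -> Int:
    # Number of warp-tile rows that fit in one block tile.
    return block_rows // warp_rows


fn total_tiles(block_rows: Int, warp_rows: Int, pipeline_stages: Int) -> Int:
    # One set of block_warps tiles per pipeline stage.
    return block_warps(block_rows, warp_rows) * pipeline_stages


def main():
    # Assumed example configuration (illustrative only).
    var block_rows = 128
    var warp_rows = 32
    var pipeline_stages = 2

    # 128 // 32 = 4 warp-tile rows per block tile.
    print("block_warps =", block_warps(block_rows, warp_rows))
    # 4 tiles per stage * 2 stages = 8 tiles in the ring.
    print("total_tiles =", total_tiles(block_rows, warp_rows, pipeline_stages))
```

Doubling pipeline_stages doubles total_tiles without changing block_warps, so deeper pipelining costs shared memory linearly in the number of stages.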
Methods
__init__
__init__(out self)
get_tiles
get_tiles(self, stage: Int, warp_tile_idx: Int) -> RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].WarpTileTupleType
Get the warp tiles for the given pipeline stage and warp tile index, one from each shared-memory tile buffer.
Returns:
WarpTileTupleType
producer
producer[warps_processed_per_producer: Int](mut self) -> ProducerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_processed_per_producer]
Create a producer view of this ring buffer.
Returns:
ProducerView
consumer
consumer[warps_computed_per_consumer: Int](mut self) -> ConsumerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_computed_per_consumer]
Create a consumer view of this ring buffer.
Returns:
ConsumerView
get_staged_idx
get_staged_idx(self, tile_idx: Int, stage: Int) -> Int
Get the staged index for a tile and stage.
Returns:
Int
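This page does not give the formula behind get_staged_idx. As an illustration only, the sketch below assumes a stage-major flattening: the tiles of stage 0 occupy slots 0 to block_warps - 1, stage 1 takes the next block_warps slots, and so on, so every (tile_idx, stage) pair maps to a distinct slot in the range 0 to total_tiles - 1. The function staged_idx is hypothetical, and the real method may order slots differently.

```mojo
# Hypothetical flattening of (tile_idx, stage) into a single slot index.
# Assumption: stage-major order. The real get_staged_idx may differ.
fn staged_idx(tile_idx: Int, stage: Int, block_warps: Int) -> Int:
    return stage * block_warps + tile_idx


def main():
    # Assumed configuration: block_warps = 4, pipeline_stages = 2.
    var block_warps = 4
    var pipeline_stages = 2
    # Enumerate every slot; indices cover 0 through
    # block_warps * pipeline_stages - 1 with no collisions.
    for stage in range(pipeline_stages):
        for tile_idx in range(block_warps):
            print(
                "stage", stage, "tile", tile_idx, "-> slot",
                staged_idx(tile_idx, stage, block_warps),
            )
```

With block_warps = 4 and pipeline_stages = 2, the loop visits slots 0 through 7, which matches total_tiles = 8.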
wait_producer_acquire
wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)
Producer waits to acquire a tile.
signal_producer_release
signal_producer_release(mut self, tile_idx: Int, stage: Int)
Producer signals it has released a tile.
wait_consumer_acquire
wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)
Consumer waits to acquire a tile.
signal_consumer_release
signal_consumer_release(mut self, tile_idx: Int, stage: Int)
Consumer signals it has released a tile.
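The four wait/signal methods form a producer-consumer handshake on each tile. The single-threaded model below illustrates the ordering such a handshake enforces, under these assumptions: one tile per pipeline stage, phase = k // pipeline_stages for K-loop iteration k, a per-slot count of producer releases, and a per-slot count of consumer releases, where a slot becomes writable again only after reads_per_warp_block consumers have released it. SlotCounters and its methods are hypothetical and do not reproduce SingleCounterSync or SplitCounterSync; in a real kernel each check would be a wait rather than a raised error.

```mojo
# Single-threaded model of a counter-based producer/consumer handshake.
# Assumption: this is NOT the SyncStrategy implementation; it only shows
# the ordering that the wait/signal methods are meant to enforce.
struct SlotCounters:
    # produced[s]: how many times the producer has filled slot s.
    # released[s]: total consumer releases received by slot s.
    var produced: List[Int]
    var released: List[Int]
    var reads_per_tile: Int

    fn __init__(out self, num_slots: Int, reads_per_tile: Int):
        var produced = List[Int]()
        var released = List[Int]()
        for _ in range(num_slots):
            produced.append(0)
            released.append(0)
        self.produced = produced^
        self.released = released^
        self.reads_per_tile = reads_per_tile

    fn producer_may_write(self, slot: Int, phase: Int) -> Bool:
        # Free once every consumer released the slot in all earlier phases.
        return self.released[slot] >= phase * self.reads_per_tile

    fn consumer_may_read(self, slot: Int, phase: Int) -> Bool:
        # Ready once the producer has filled the slot phase + 1 times.
        return self.produced[slot] >= phase + 1

    fn producer_release(mut self, slot: Int):
        self.produced[slot] += 1

    fn consumer_release(mut self, slot: Int):
        self.released[slot] += 1


def main():
    var stages = 2
    var reads_per_tile = 2  # plays the role of reads_per_warp_block
    var c = SlotCounters(stages, reads_per_tile)
    for k in range(6):  # six K-loop iterations
        var slot = k % stages
        var phase = k // stages  # advances on each full pass over the ring
        if not c.producer_may_write(slot, phase):
            raise Error("producer would block")
        c.producer_release(slot)
        for _ in range(reads_per_tile):
            if not c.consumer_may_read(slot, phase):
                raise Error("consumer would block")
            c.consumer_release(slot)
    print("produced:", c.produced[0], c.produced[1])
    print("released:", c.released[0], c.released[1])
```

Raising an error wherever a warp would block makes the invariant explicit: in this fixed order no step ever has to wait, and after six iterations each slot has been produced 3 times and released 6 times (3 phases times 2 readers).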
get_producer_phase_increment
get_consumer_phase_increment