Skip to main content

Mojo struct

RingBuffer

struct RingBuffer[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy]

Ring buffer for coordinating producer-consumer warps in matrix multiplication.

Parameters

  • dtype (DType): Data type of elements.
  • layout (Layout): Memory layout for shared memory tiles.
  • pipeline_stages (Int): Number of stages for software pipelining.
  • block_rows (Int): Number of rows in block-level tiles.
  • block_cols (Int): Number of columns in block-level tiles.
  • warp_rows (Int): Number of rows in warp-level tiles.
  • warp_cols (Int): Number of columns in warp-level tiles.
  • reads_per_warp_block (Int): How many consumer warps read each tile.
  • tile_buffers (Int): Number of separate tile buffers (usually 1).
  • sync_strategy_type (SyncStrategy): Synchronization strategy (SingleCounterSync or SplitCounterSync).

Fields

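  • smem_buffers (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SMemBuffersType): The shared-memory tile buffers. This is a `StaticTuple` holding `tile_buffers` `SMemBuffer` instances.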
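  • sync_strategy (sync_strategy_type): The synchronization strategy instance (for example, `SingleCounterSync` or `SplitCounterSync`) used to coordinate producer and consumer warps.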

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

comptime __del__is_trivial = sync_strategy_type.__del__is_trivial

block_warps

comptime block_warps = (block_rows // warp_rows)

SMemBuffersType

comptime SMemBuffersType = StaticTuple[SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols], tile_buffers]

SmemBufferType

comptime SmemBufferType = SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols]

total_tiles

comptime total_tiles = (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].block_warps * pipeline_stages)

WarpTileTupleType

comptime WarpTileTupleType = StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, block_rows, block_cols]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols](), 128, warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols]() if _tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, block_rows, block_cols]()[0], warp_rows, warp_cols](), alignment=128], tile_buffers]

WarpTileType

comptime WarpTileType = RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SmemBufferType.WarpTileType

Methods

__init__

__init__(out self)

get_tiles

get_tiles(self, stage: Int, warp_tile_idx: Int) -> RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].WarpTileTupleType

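Get the warp tiles for pipeline stage `stage` and warp tile index `warp_tile_idx` from shared memory, with one tile taken from each of the `tile_buffers` buffers.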

Returns:

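`RingBuffer[...].WarpTileTupleType`: a `StaticTuple` of `tile_buffers` shared-memory warp tiles (`LayoutTensor`s).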

producer

producer[warps_processed_per_producer: Int](mut self) -> ProducerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_processed_per_producer]

Create a producer view of this ring buffer.

Returns:

ProducerView

consumer

consumer[warps_computed_per_consumer: Int](mut self) -> ConsumerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_computed_per_consumer]

Create a consumer view of this ring buffer.

Returns:

ConsumerView

get_staged_idx

get_staged_idx(self, tile_idx: Int, stage: Int) -> Int

Get the staged index for a tile and stage.

Returns:

Int

wait_producer_acquire

wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)

Producer waits to acquire a tile.

signal_producer_release

signal_producer_release(mut self, tile_idx: Int, stage: Int)

Producer signals it has released a tile.

wait_consumer_acquire

wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)

Consumer waits to acquire a tile.

signal_consumer_release

signal_consumer_release(mut self, tile_idx: Int, stage: Int)

Consumer signals it has released a tile.

get_producer_phase_increment

get_producer_phase_increment(self) -> Int32

Get the phase increment for producers.

Returns:

Int32

get_consumer_phase_increment

get_consumer_phase_increment(self) -> Int32

Get the phase increment for consumers.

Returns:

Int32

Was this page helpful?