Mojo struct
RingBuffer
@register_passable(trivial)
struct RingBuffer[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]
Ring buffer with tile storage for SM100 producer-consumer synchronization.
This is the SM90-style API: tiles are stored in the ring buffer and returned directly from get_tiles().
Template Parameters:
- a_type (DType): Data type for A matrix tiles.
- b_type (DType): Data type for B matrix tiles.
- a_tile_layout (Layout): Memory layout for A tiles.
- b_tile_layout (Layout): Memory layout for B tiles.
- num_pipeline_stages (Int): Total number of tile stages.
- num_group_stages (Int): Number of synchronization stages.
- k_group_size (Int): Number of tiles per synchronization stage.
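The stage-related parameters appear to be linked. If each synchronization stage owns k_group_size consecutive tile slots (an assumption, not stated on this page), then num_pipeline_stages must equal num_group_stages * k_group_size. The Python sketch below checks that relation. check_ring_config is a hypothetical helper and is not part of the Mojo API.

```python
def check_ring_config(num_pipeline_stages: int, num_group_stages: int, k_group_size: int) -> None:
    """Hypothetical check: each sync stage is assumed to own k_group_size tile slots."""
    if num_group_stages <= 0 or k_group_size <= 0:
        raise ValueError("num_group_stages and k_group_size must be positive")
    if num_pipeline_stages != num_group_stages * k_group_size:
        raise ValueError(
            f"expected num_pipeline_stages == {num_group_stages} * {k_group_size}, "
            f"got {num_pipeline_stages}"
        )


check_ring_config(num_pipeline_stages=8, num_group_stages=4, k_group_size=2)  # passes
try:
    check_ring_config(num_pipeline_stages=6, num_group_stages=4, k_group_size=2)
except ValueError as err:
    print(err)  # expected num_pipeline_stages == 4 * 2, got 6
```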
Fields
- pipeline (RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].Pipeline): Producer-consumer barrier pipeline.
- a_tiles (RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray): A matrix tile array in shared memory.
- b_tiles (RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray): B matrix tile array in shared memory.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
Movable,
UnknownDestructibility
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
ATile
comptime ATile = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray.Tile
ATileArray
comptime ATileArray = SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128]
BTile
comptime BTile = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray.Tile
BTileArray
comptime BTileArray = SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]
Pipeline
comptime Pipeline = ProducerConsumerPipeline[num_group_stages]
Methods
__init__
__init__(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]) -> Self
Initialize the ring buffer from a storage pointer.
The pipeline is created internally from storage_ptr. The barriers must be initialized with init_barriers() before first use.
Args:
- storage_ptr (LegacyUnsafePointer): Pointer to shared memory barrier storage.
- a_tiles (SMemTileArrayType): A matrix tile array in shared memory.
- b_tiles (SMemTileArrayType): B matrix tile array in shared memory.
init_barriers
static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)
Initialize the pipeline barriers. Called once, by the single thread selected with elect_one.
Args:
- storage_ptr (LegacyUnsafePointer): Pointer to shared memory barrier storage.
- producer_arv_count (Int32): Expected arrival count for producer barriers.
- consumer_arv_count (Int32): Expected arrival count for consumer barriers.
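The arrival counts are easier to see with a single-threaded Python model of a phase-tracking shared-memory barrier. A phase completes once the configured number of arrivals has been recorded, and completing a phase flips a parity bit that waiters compare against. MbarrierModel and init_barriers_model are hypothetical names. The layout of one "full" barrier (sized by producer_arv_count) and one "empty" barrier (sized by consumer_arv_count) per synchronization stage is an assumption about how the pipeline is organized, not documented behavior.

```python
class MbarrierModel:
    """Hypothetical single-threaded model of a phase-tracking barrier (not the Mojo API)."""

    def __init__(self, arrival_count: int):
        self.arrival_count = arrival_count  # arrivals needed to complete one phase
        self.pending = arrival_count
        self.phase = 0  # parity of the phase currently in progress

    def arrive(self) -> None:
        self.pending -= 1
        if self.pending == 0:
            # Phase complete: flip the parity and re-arm for the next phase.
            self.phase ^= 1
            self.pending = self.arrival_count

    def phase_done(self, parity: int) -> bool:
        """True once the phase with the given parity has completed."""
        return self.phase != parity


def init_barriers_model(num_group_stages: int, producer_arv_count: int, consumer_arv_count: int):
    """Hypothetical: one 'full' and one 'empty' barrier per synchronization stage."""
    full = [MbarrierModel(producer_arv_count) for _ in range(num_group_stages)]
    empty = [MbarrierModel(consumer_arv_count) for _ in range(num_group_stages)]
    return full, empty


full, empty = init_barriers_model(num_group_stages=2, producer_arv_count=1, consumer_arv_count=4)
for _ in range(4):
    empty[0].arrive()
print(empty[0].phase_done(0))  # True: all 4 consumer arrivals recorded
```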
producer
producer[origin: MutOrigin](ref [origin] self) -> Producer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]
Get a producer view that exposes the get_tiles() API.
Returns:
Producer
consumer
consumer[origin: MutOrigin](ref [origin] self) -> Consumer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]
Get a consumer view that exposes the get_tiles() API.
Returns:
Consumer
get_producer_tiles
get_producer_tiles(mut self) -> Tuple[UInt32, MbarPtr, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray]
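Wait for a free slot, then return the stage, barrier, and tile arrays.
Synchronization is handled internally: the call waits until the consumer has released the slot. Use stage to index the tile arrays: tiles.a_tiles[stage * k_group_size + j], where j is the tile's position within the k-group (0 <= j < k_group_size).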
Returns:
Tuple: Tuple of (stage, barrier, a_tiles, b_tiles).
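The indexing rule stage * k_group_size + j places the k_group_size tiles of one stage in consecutive slots of each tile array. A short Python sketch of that arithmetic follows. tile_slot is a hypothetical helper, not part of the Mojo API.

```python
def tile_slot(stage: int, j: int, k_group_size: int) -> int:
    """Flat tile-array index of tile j in synchronization stage `stage`."""
    if not 0 <= j < k_group_size:
        raise ValueError("j must be in [0, k_group_size)")
    return stage * k_group_size + j


k_group_size = 4
# Slots used by stage 0 and stage 1.
slots = [[tile_slot(s, j, k_group_size) for j in range(k_group_size)] for s in range(2)]
print(slots)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```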
enqueue_tile
enqueue_tile(mut self)
Signal that the producer has finished loading the current stage, then advance to the next stage.
get_tile
get_tile[tile_idx_in_group: Int](self, stage: UInt32) -> Tuple[RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATile, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTile]
Get the A and B tiles at position tile_idx_in_group within the k-group for stage.
Returns:
Tuple: Tuple of (ATile, BTile).
get_consumer_tiles
get_consumer_tiles(mut self) -> Tuple[UInt32, MbarPtr, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray]
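Wait for a filled slot, then return the stage, barrier, and tile arrays.
Synchronization is handled internally: the call waits until the producer has filled the slot. Use stage to index the tile arrays: tiles.a_tiles[stage * k_group_size + j], where j is the tile's position within the k-group (0 <= j < k_group_size).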
Returns:
Tuple: Tuple of (stage, mbar, a_tiles, b_tiles).
release_slot
release_slot(mut self)
Signal that the consumer has finished with the current slot, then advance to the next stage.
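The stage bookkeeping behind get_producer_tiles(), enqueue_tile(), get_consumer_tiles(), and release_slot() can be modeled in Python. This is a sketch under assumptions. It uses the common phase-parity scheme: one full and one empty barrier per stage, one arrival per phase, and an expected parity that flips each time a side wraps around the ring. RingModel is a hypothetical single-threaded model, not the Mojo implementation. Where real code would block, the model exposes can_proceed checks and raises instead.

```python
class RingModel:
    """Hypothetical single-threaded model of stage/phase bookkeeping (not the Mojo API)."""

    def __init__(self, num_group_stages: int):
        self.n = num_group_stages
        # Parity bit of each stage's barrier; flips when that barrier completes a phase.
        self.full_parity = [0] * num_group_stages
        self.empty_parity = [0] * num_group_stages
        self.producer_stage, self.producer_phase = 0, 1  # starts at 1: all slots begin empty
        self.consumer_stage, self.consumer_phase = 0, 0

    def _advance(self, stage: int, phase: int) -> tuple[int, int]:
        stage += 1
        if stage == self.n:  # wrapped around the ring: flip the expected parity
            stage, phase = 0, phase ^ 1
        return stage, phase

    def producer_can_proceed(self) -> bool:
        # What get_producer_tiles() would wait for: the slot has been released.
        return self.empty_parity[self.producer_stage] != self.producer_phase

    def enqueue_tile(self) -> None:
        if not self.producer_can_proceed():
            raise RuntimeError("would block: consumer has not released this slot")
        self.full_parity[self.producer_stage] ^= 1  # signal "slot filled"
        self.producer_stage, self.producer_phase = self._advance(
            self.producer_stage, self.producer_phase)

    def consumer_can_proceed(self) -> bool:
        # What get_consumer_tiles() would wait for: the slot has been filled.
        return self.full_parity[self.consumer_stage] != self.consumer_phase

    def release_slot(self) -> None:
        if not self.consumer_can_proceed():
            raise RuntimeError("would block: producer has not filled this slot")
        self.empty_parity[self.consumer_stage] ^= 1  # signal "slot free"
        self.consumer_stage, self.consumer_phase = self._advance(
            self.consumer_stage, self.consumer_phase)


ring = RingModel(num_group_stages=2)
ring.enqueue_tile()
ring.enqueue_tile()
print(ring.producer_can_proceed())  # False: both slots are full
ring.release_slot()
print(ring.producer_can_proceed())  # True: stage 0 was released
```

Starting the producer at parity 1 lets its first pass over the ring proceed without waiting, because every slot begins empty.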
producer_stage
consumer_stage
producer_mbar
producer_mbar(self, stage: UInt32) -> MbarPtr
Returns:
MbarPtr
consumer_mbar
consumer_mbar(self, stage: UInt32) -> MbarPtr
Returns:
MbarPtr