Skip to main content

Mojo struct

RingBuffer

struct RingBuffer[pipeline_stages: Int, a_buffer_layout: Layout, b_buffer_layout: Layout, SmemBufferDataType: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, //, SmemBufferTypeA: AnyStruct[SharedMemoryBuffer[SmemBufferDataType, a_buffer_layout, pipeline_stages, BM, BK, WM, WK]], SmemBufferTypeB: AnyStruct[SharedMemoryBuffer[SmemBufferDataType, b_buffer_layout, pipeline_stages, BN, BK, WN, WK]], consumer_warps: Int]

Manages access to shared memory tiles using barriers that reside in shared memory.

Fields

  • barrier_a (LayoutTensor[DType.int32, Layout.row_major((BM // WM), (consumer_warps * pipeline_stages)), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
  • barrier_b (LayoutTensor[DType.int32, Layout.row_major((BN // WN), (consumer_warps * pipeline_stages)), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
  • smem_buffer_a (SmemBufferTypeA):
  • smem_buffer_b (SmemBufferTypeB):

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = SmemBufferTypeB.__del__is_trivial if SmemBufferTypeA.__del__is_trivial else SmemBufferTypeA.__del__is_trivial

BarrierTensorType

alias BarrierTensorType[warp_tile_count: Int] = LayoutTensor[DType.int32, Layout.row_major(warp_tile_count, (consumer_warps * pipeline_stages)), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]

Parameters

  • warp_tile_count (Int):

SharedMemoryBufferType

alias SharedMemoryBufferType[is_a: Bool] = SharedMemoryBuffer[SmemBufferDataType, a_buffer_layout if is_a else b_buffer_layout, pipeline_stages, BM if is_a else BN, BK, WM if is_a else WN, WK]

Parameters

  • is_a (Bool):

warps_per_block_m

alias warps_per_block_m = (BM // WM)

warps_per_block_n

alias warps_per_block_n = (BN // WN)

Methods

__init__

__init__(out self, smem_buffer_a: SharedMemoryBuffer[SmemBufferDataType, a_buffer_layout, pipeline_stages, BM, BK, WM, WK], smem_buffer_b: SharedMemoryBuffer[SmemBufferDataType, b_buffer_layout, pipeline_stages, BN, BK, WN, WK])

await_shared_memory_warp_tile

await_shared_memory_warp_tile[is_a: Bool, is_producer: Bool](mut self, mut phase: Int, stage: Int, tile_idx: Int) -> LayoutTensor[SmemBufferDataType, LayoutTensor._compute_tile_layout[True, SmemBufferDataType, LayoutTensor._compute_tile_layout[True, SmemBufferDataType, pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), False, 128, BM if is_a else BN, BK]()[0], MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _tile_is_masked[pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), BM if is_a else BN, BK](), 128, WM if is_a else WN, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), linear_idx_type=_get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), masked=_tile_is_masked[pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), BM if is_a else BN, BK]() if _tile_is_masked[pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), BM if is_a else BN, BK]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, SmemBufferDataType, pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), 
_get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), False, 128, BM if is_a else BN, BK]()[0], WM if is_a else WN, WK](), alignment=128]

Returns:

LayoutTensor

commit

commit[is_a: Bool](mut self, stage: Int, tile_idx: Int)

Was this page helpful?