Mojo struct
RingBuffer
struct RingBuffer[pipeline_stages: Int, a_buffer_layout: Layout, b_buffer_layout: Layout, SmemBufferDataType: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, //, SmemBufferTypeA: AnyStruct[SharedMemoryBuffer[SmemBufferDataType, a_buffer_layout, pipeline_stages, BM, BK, WM, WK]], SmemBufferTypeB: AnyStruct[SharedMemoryBuffer[SmemBufferDataType, b_buffer_layout, pipeline_stages, BN, BK, WN, WK]], consumer_warps: Int]
Manages access to shared-memory tiles, synchronized through barrier tensors that are themselves stored in shared memory.
Fields
- barrier_a (LayoutTensor[DType.int32, Layout.row_major((BM // WM), (consumer_warps * pipeline_stages)), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
- barrier_b (LayoutTensor[DType.int32, Layout.row_major((BN // WN), (consumer_warps * pipeline_stages)), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
- smem_buffer_a (SmemBufferTypeA):
- smem_buffer_b (SmemBufferTypeB):
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = SmemBufferTypeB.__del__is_trivial if SmemBufferTypeA.__del__is_trivial else SmemBufferTypeA.__del__is_trivial
BarrierTensorType
alias BarrierTensorType[warp_tile_count: Int] = LayoutTensor[DType.int32, Layout.row_major(warp_tile_count, (consumer_warps * pipeline_stages)), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]
Parameters
- warp_tile_count (Int):
SharedMemoryBufferType
alias SharedMemoryBufferType[is_a: Bool] = SharedMemoryBuffer[SmemBufferDataType, a_buffer_layout if is_a else b_buffer_layout, pipeline_stages, BM if is_a else BN, BK, WM if is_a else WN, WK]
Parameters
- is_a (Bool):
warps_per_block_m
alias warps_per_block_m = (BM // WM)
warps_per_block_n
alias warps_per_block_n = (BN // WN)
Methods
__init__
__init__(out self, smem_buffer_a: SharedMemoryBuffer[SmemBufferDataType, a_buffer_layout, pipeline_stages, BM, BK, WM, WK], smem_buffer_b: SharedMemoryBuffer[SmemBufferDataType, b_buffer_layout, pipeline_stages, BN, BK, WN, WK])
await_shared_memory_warp_tile
await_shared_memory_warp_tile[is_a: Bool, is_producer: Bool](mut self, mut phase: Int, stage: Int, tile_idx: Int) -> LayoutTensor[SmemBufferDataType, LayoutTensor._compute_tile_layout[True, SmemBufferDataType, LayoutTensor._compute_tile_layout[True, SmemBufferDataType, pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), False, 128, BM if is_a else BN, BK]()[0], MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _tile_is_masked[pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), BM if is_a else BN, BK](), 128, WM if is_a else WN, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), linear_idx_type=_get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), masked=_tile_is_masked[pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), BM if is_a else BN, BK]() if _tile_is_masked[pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), BM if is_a else BN, BK]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, SmemBufferDataType, pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), MutableAnyOrigin, AddressSpace(3), Layout(IntTuple(1), IntTuple(1)), 
_get_layout_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), _get_index_type(pipeline_layout[a_buffer_layout if is_a else b_buffer_layout, pipeline_stages](), AddressSpace(3)), False, 128, BM if is_a else BN, BK]()[0], WM if is_a else WN, WK](), alignment=128]
Returns:
commit
commit[is_a: Bool](mut self, stage: Int, tile_idx: Int)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!