Mojo struct
RingBuffer
struct RingBuffer[pipeline_stages: Int, a_tile_layout: Layout, b_tile_layout: Layout, TileTypeA: DType, TileTypeB: DType, WM: Int, WN: Int, WK: Int, warps_per_block_m: Int, warps_per_block_n: Int]
Simulates TMA barriers on AMD GPUs. This pipeline can be used both for multistage pipelines (e.g. double buffering), where all threads act as both producers and consumers, and for warp-specialized pipelines.
Fields
- barrier_a (LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
- barrier_b (LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
- smem_tile_a (SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK]):
- smem_tile_b (SharedMemoryBuffer[TileTypeB, b_tile_layout, pipeline_stages, WN, WK]):
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = SharedMemoryBuffer[TileTypeB, b_tile_layout, pipeline_stages, WN, WK].__del__is_trivial if SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, 
Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial
BarrierTensorType
alias BarrierTensorType[warp_count: Int] = LayoutTensor[DType.int32, Layout.row_major(warp_count, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]
Parameters
- warp_count (Int):
SharedMemoryBufferType
alias SharedMemoryBufferType[is_a: Bool] = SharedMemoryBuffer[TileTypeA if is_a else TileTypeB, a_tile_layout if is_a else b_tile_layout, pipeline_stages, WM if is_a else WN, WK]
Parameters
- is_a (Bool):
Methods
__init__
__init__(out self, smem_tile_a: SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK], smem_tile_b: SharedMemoryBuffer[TileTypeB, b_tile_layout, pipeline_stages, WN, WK])
put_tile
put_tile[is_a: Bool](mut self, mut phase: Int, stage: Int, tile_idx: Int) -> LayoutTensor[TileTypeA if is_a else TileTypeB, LayoutTensor._compute_tile_layout[True, TileTypeA if is_a else TileTypeB, Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), _get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), False, 128, WM if is_a else WN, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else 
b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), linear_idx_type=_get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), masked=_tile_is_masked[Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, WM if is_a else WN, WK](), alignment=128]
Returns:
get_tile
get_tile[is_a: Bool](mut self, mut phase: Int, stage: Int, tile_idx: Int) -> LayoutTensor[TileTypeA if is_a else TileTypeB, LayoutTensor._compute_tile_layout[True, TileTypeA if is_a else TileTypeB, Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), _get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), False, 128, WM if is_a else WN, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else 
b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), linear_idx_type=_get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), masked=_tile_is_masked[Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, WM if is_a else WN, WK](), alignment=128]
Returns:
commit
commit[is_a: Bool](mut self, stage: Int, tile_idx: Int)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!