Skip to main content

Mojo struct

RingBuffer

struct RingBuffer[pipeline_stages: Int, a_tile_layout: Layout, b_tile_layout: Layout, TileTypeA: DType, TileTypeB: DType, WM: Int, WN: Int, WK: Int, warps_per_block_m: Int, warps_per_block_n: Int]

Simulates TMA barriers on AMD. This pipeline can be used both for multistage (double-buffering) kernels, where all threads are producers and consumers, and for warp-specialized pipelines.

Fields

  • barrier_a (LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
  • barrier_b (LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]):
  • smem_tile_a (SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK]):
  • smem_tile_b (SharedMemoryBuffer[TileTypeB, b_tile_layout, pipeline_stages, WN, WK]):

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = SharedMemoryBuffer[TileTypeB, b_tile_layout, pipeline_stages, WN, WK].__del__is_trivial if SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_n, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial if LayoutTensor[DType.int32, 
Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial else LayoutTensor[DType.int32, Layout.row_major(warps_per_block_m, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32].__del__is_trivial

BarrierTensorType

alias BarrierTensorType[warp_count: Int] = LayoutTensor[DType.int32, Layout.row_major(warp_count, pipeline_stages), MutableAnyOrigin, address_space=AddressSpace(3), alignment=32]

Parameters

  • warp_count (Int):

SharedMemoryBufferType

alias SharedMemoryBufferType[is_a: Bool] = SharedMemoryBuffer[TileTypeA if is_a else TileTypeB, a_tile_layout if is_a else b_tile_layout, pipeline_stages, WM if is_a else WN, WK]

Parameters

  • is_a (Bool):

Methods

__init__

__init__(out self, smem_tile_a: SharedMemoryBuffer[TileTypeA, a_tile_layout, pipeline_stages, WM, WK], smem_tile_b: SharedMemoryBuffer[TileTypeB, b_tile_layout, pipeline_stages, WN, WK])

put_tile

put_tile[is_a: Bool](mut self, mut phase: Int, stage: Int, tile_idx: Int) -> LayoutTensor[TileTypeA if is_a else TileTypeB, LayoutTensor._compute_tile_layout[True, TileTypeA if is_a else TileTypeB, Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), _get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), False, 128, WM if is_a else WN, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else 
b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), linear_idx_type=_get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), masked=_tile_is_masked[Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, WM if is_a else WN, WK](), alignment=128]

Returns:

LayoutTensor

get_tile

get_tile[is_a: Bool](mut self, mut phase: Int, stage: Int, tile_idx: Int) -> LayoutTensor[TileTypeA if is_a else TileTypeB, LayoutTensor._compute_tile_layout[True, TileTypeA if is_a else TileTypeB, Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, MutableAnyOrigin, AddressSpace(3), Layout.__init__(IntTuple[__origin_of()](1), IntTuple[__origin_of()](1)), _get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), _get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), False, 128, WM if is_a else WN, WK]()[0], MutableAnyOrigin, address_space=AddressSpace(3), layout_int_type=_get_layout_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else 
b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), linear_idx_type=_get_index_type(Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, AddressSpace(3)), masked=_tile_is_masked[Layout.__init__(IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.shape[1], a_tile_layout if is_a else b_tile_layout.shape[2], Tuple[]()), IntTuple.__init__[__origin_of()](a_tile_layout if is_a else b_tile_layout.stride[1], a_tile_layout if is_a else b_tile_layout.stride[2], Tuple[]())) if ((a_tile_layout if is_a else b_tile_layout.rank() == 2) ^ True) else a_tile_layout if is_a else b_tile_layout, WM if is_a else WN, WK](), alignment=128]

Returns:

LayoutTensor

commit

commit[is_a: Bool](mut self, stage: Int, tile_idx: Int)

Was this page helpful?