Mojo struct
TileBuffers
struct TileBuffers[in_type: DType, a_layout: Layout, b_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_threads: Int, alignment: Int, enable_swizzle: Bool, load_width: Int, loading_warps: Int = 8]
Double-buffered LDS tiles and TileLoaders for ping-pong matmul.
a_layout and b_layout are infer-only parameters (they appear before the `//` marker in the parameter list), so they are inferred automatically from the input tensors passed to `__init__`. K is a comptime member derived from `a_layout.shape[1]`.
Fields
- a_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair])
- b_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair])
- a_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair])
- b_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair])
- loader_a (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].ATileLoader)
- loader_b (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BTileLoader)
- warp_id_m (Int)
- k_offset (Int)
- warp_shift_rows (Int)
Implemented traits
AnyType,
UnknownDestructibility
comptime members
__del__is_trivial
comptime __del__is_trivial = False
AHalfTile
comptime AHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
AHalfTilePair
comptime AHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile]
AMmaTile
comptime AMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), False, alignment, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK](), alignment=alignment]
AMmaTilePair
comptime AMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile]
ATileLoader
comptime ATileLoader = TileLoaderLDS[in_type, a_layout, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width]
BHalfTile
comptime BHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
BHalfTilePair
comptime BHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile]
BMmaTile
comptime BMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK](), alignment, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() if 
_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK](), alignment=alignment]
BMmaTilePair
comptime BMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile]
BTileLoader
comptime BTileLoader = TileLoaderLDS[in_type, b_layout, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width]
byte_swizzle
comptime byte_swizzle = OptionalReg[Swizzle](Swizzle(1, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_byte_base, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_shift)) if enable_swizzle else OptionalReg[Swizzle]()
elem_size
comptime elem_size = size_of[in_type]()
elements_per_warp
comptime elements_per_warp = (WARP_SIZE * load_width)
half_BM
comptime half_BM = (BM // 2)
half_BN
comptime half_BN = (BN // 2)
half_tile_layout
comptime half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK)
HalfTile
comptime HalfTile[rows: Int] = LayoutTensor[in_type, Layout.row_major(rows, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
Parameters
- rows (Int): Number of rows in the tile; the tile layout is `Layout.row_major(rows, BK)`.
K
comptime K = a_layout.shape[1].value()
loading_threads
comptime loading_threads = (loading_warps * WARP_SIZE)
loads_per_row
comptime loads_per_row = (BK // load_width)
mma_tile_m
comptime mma_tile_m = (WM // 2)
mma_tile_n
comptime mma_tile_n = (WN // 2)
rows_per_iter_4warp
comptime rows_per_iter_4warp = ((4 * WARP_SIZE) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)
rows_per_load_iteration
comptime rows_per_load_iteration = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].loading_threads // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)
rows_per_warp
comptime rows_per_warp = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].elements_per_warp // BK)
smem_ptr
comptime smem_ptr = LegacyUnsafePointer[Scalar[in_type], address_space=AddressSpace.SHARED]
SMemTile
comptime SMemTile[rows: Int, cols: Int] = LayoutTensor[in_type, Layout.row_major(rows, cols), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
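Parameters
- rows (Int): Number of rows in the tile.
- cols (Int): Number of columns in the tile; the tile layout is `Layout.row_major(rows, cols)`.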
swizzle_byte_base
comptime swizzle_byte_base = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_elem_base + log2_floor(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].elem_size))
swizzle_elem_base
comptime swizzle_elem_base = log2_floor((TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_subtile_cols // 2))
swizzle_shift
comptime swizzle_shift = log2_floor(16)
swizzle_subtile_cols
comptime swizzle_subtile_cols = (4 * load_width)
swizzle_subtile_rows
comptime swizzle_subtile_rows = 16
TileLoader
comptime TileLoader[src_layout: Layout] = TileLoaderLDS[in_type, src_layout, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width]
Parameters
- src_layout (Layout): Layout of the source tensor that the loader reads from.
total_warps
comptime total_warps = 8
vmcnt_per_load_a
comptime vmcnt_per_load_a = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)
vmcnt_per_load_a_4warp
comptime vmcnt_per_load_a_4warp = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)
vmcnt_per_load_ab
comptime vmcnt_per_load_ab = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_a + TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_b)
vmcnt_per_load_b
comptime vmcnt_per_load_b = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)
vmcnt_per_load_b_4warp
comptime vmcnt_per_load_b_4warp = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)
Methods
__init__
__init__(out self, a: LayoutTensor[in_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], block_row: Int, block_col: Int, warp_id: Int, warp_id_m: Int, warp_id_n: Int, lane_id: Int)
Initialize LDS tiles and loaders. Layouts inferred from a and b tensors.
advance_k
advance_k(mut self)
Advance k_offset by BK for the next K iteration.
load_a
load_a[stage: Int, which: Int](self)
Load A[stage][which] using 8 warps.
load_b
load_b[stage: Int, which: Int](self)
Load B[stage][which] using 8 warps.
load_a_as_group
load_a_as_group[stage: Int, target_group: Int](self, caller_group: Int)
Load A[stage][target_group] using 4 warps. Only executes if caller_group == target_group.
load_b_as_group
load_b_as_group[stage: Int, which: Int](self, caller_group: Int, loading_group: Int)
Load B[stage][which] using 4 warps. Only executes if caller_group == loading_group.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!