Skip to main content

Mojo struct

TileBuffers

struct TileBuffers[in_type: DType, a_layout: Layout, b_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, MMA_K: Int, num_threads: Int, alignment: Int, enable_swizzle: Bool, load_width: Int, loading_warps: Int = 8]

Double-buffered LDS tiles and TileLoaders for ping-pong matmul.

a_layout and b_layout are infer-only parameters (note the //), automatically extracted from the input tensors passed to __init__. K is derived as a comptime value from a_layout.shape[1].

Fields

  • a_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair]):
  • b_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair]):
  • a_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair]):
  • b_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair]):
  • loader_a (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].ATileLoader):
  • loader_b (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BTileLoader):
  • warp_id_m (Int):
  • warp_shift_rows (Int):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

__del__is_trivial

comptime __del__is_trivial = False

AHalfTile

comptime AHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

AHalfTilePair

comptime AHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile]

AMmaTile

comptime AMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), False, alignment, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK](), alignment=alignment]

AMmaTilePair

comptime AMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile]

ATileLoader

comptime ATileLoader = TileLoaderLDS[in_type, a_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]

BHalfTile

comptime BHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

BHalfTilePair

comptime BHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile]

BMmaTile

comptime BMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK](), alignment, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() if _tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK](), alignment=alignment]

BMmaTilePair

comptime BMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile]

BTileLoader

comptime BTileLoader = TileLoaderLDS[in_type, b_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]

byte_swizzle

comptime byte_swizzle = OptionalReg[Swizzle](Swizzle(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_log_tile, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_base, 4)) if enable_swizzle else OptionalReg[Swizzle]()

elem_size

comptime elem_size = size_of[in_type]()

elements_per_warp

comptime elements_per_warp = (WARP_SIZE * load_width)

frag_bytes

comptime frag_bytes = (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].lds_frag_width * size_of[in_type]())

half_BM

comptime half_BM = (BM // 2)

half_BN

comptime half_BN = (BN // 2)

half_tile_layout

comptime half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK)

HalfTile

comptime HalfTile[rows: Int] = LayoutTensor[in_type, Layout.row_major(rows, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

Parameters

  • rows (Int):

K

comptime K = a_layout.shape[1].value()

lds_frag_width

comptime lds_frag_width = 16 if ((in_type == DType.float8_e4m3fn) and (MMA_K == 128)) else TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_frag_width

loading_threads

comptime loading_threads = (loading_warps * WARP_SIZE)

loads_per_row

comptime loads_per_row = (BK // load_width)

mma_frag_width

comptime mma_frag_width = ((16 * MMA_K) // WARP_SIZE)

mma_tile_m

comptime mma_tile_m = (WM // 2)

mma_tile_n

comptime mma_tile_n = (WN // 2)

rows_per_iter_4warp

comptime rows_per_iter_4warp = ((4 * WARP_SIZE) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)

rows_per_load_iteration

comptime rows_per_load_iteration = (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loading_threads // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)

rows_per_warp

comptime rows_per_warp = (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elements_per_warp // BK)

smem_ptr

comptime smem_ptr = LegacyUnsafePointer[Scalar[in_type], address_space=AddressSpace.SHARED]

SMemTile

comptime SMemTile[rows: Int, cols: Int] = LayoutTensor[in_type, Layout.row_major(rows, cols), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

Parameters

  • rows (Int):
  • cols (Int):

swizzle_base

comptime swizzle_base = log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].frag_bytes) if (in_type == DType.float8_e4m3fn) else (log2_floor((TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_subtile_cols // 2)) + log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elem_size))

swizzle_log_tile

comptime swizzle_log_tile = (log2_floor((MMA_K // 32)) + 1)

swizzle_shift

comptime swizzle_shift = 4

swizzle_subtile_cols

comptime swizzle_subtile_cols = (4 * load_width)

TileLoader

comptime TileLoader[src_layout: Layout] = TileLoaderLDS[in_type, src_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]

Parameters

  • src_layout (Layout):

total_warps

comptime total_warps = 8

use_fp8_row_major

comptime use_fp8_row_major = (in_type == DType.float8_e4m3fn)

use_split_k

comptime use_split_k = ((in_type == DType.float8_e4m3fn) and (MMA_K == 128))

vmcnt_per_load_a

comptime vmcnt_per_load_a = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)

vmcnt_per_load_a_4warp

comptime vmcnt_per_load_a_4warp = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)

vmcnt_per_load_ab

comptime vmcnt_per_load_ab = (TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_a + TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_b)

vmcnt_per_load_b

comptime vmcnt_per_load_b = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)

vmcnt_per_load_b_4warp

comptime vmcnt_per_load_b_4warp = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)

Methods

__init__

__init__(out self, a: LayoutTensor[in_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], block_row: Int, block_col: Int, warp_id: Int, warp_id_m: Int, warp_id_n: Int, lane_id: Int)

Initialize the LDS tiles and loaders. The a_layout and b_layout parameters are inferred from the a and b tensors.

load_a

load_a[stage: Int, which: Int, *, k: Int](self)

Load A[stage][which] from global to LDS using all 8 warps.

load_b

load_b[stage: Int, which: Int, *, k: Int](self)

Load B[stage][which] from global to LDS using all 8 warps.

load_a_as_group

load_a_as_group[stage: Int, target_group: Int, *, k: Int](self, caller_group: Int)

Load A[stage][target_group] from global to LDS using 4 warps.

load_b_as_group

load_b_as_group[stage: Int, which: Int, *, k: Int](self, caller_group: Int, loading_group: Int)

Load B[stage][which] from global to LDS using 4 warps.

Was this page helpful?