Skip to main content

Mojo struct

TileBuffers

struct TileBuffers[in_type: DType, a_layout: Layout, b_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, MMA_M: Int, MMA_K: Int, num_threads: Int, alignment: Int, enable_swizzle: Bool, load_width: Int, loading_warps: Int = 8]

Double-buffered LDS (shared-memory) tiles and TileLoaders for ping-pong matmul.

a_layout and b_layout are infer-only parameters (note the // in the parameter list); they are automatically extracted from the input tensors passed to __init__. K is derived as a comptime value from a_layout.shape[1].

Fields

  • a_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair]):
  • b_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair]):
  • a_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair]):
  • b_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair]):
  • loader_a (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].ATileLoader):
  • loader_b (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BTileLoader):
  • warp_id_m (Int):

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_half_tile_layout

comptime a_half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK)

AHalfTile

comptime AHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

AHalfTilePair

comptime AHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile]

AMmaTile

comptime AMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK](), alignment=alignment]

AMmaTilePair

comptime AMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile]

ATileLoader

comptime ATileLoader = TileLoaderLDS[in_type, a_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].a_half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]

b_half_tile_layout

comptime b_half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK)

BHalfTile

comptime BHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

BHalfTilePair

comptime BHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile]

BMmaTile

comptime BMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() or _tile_is_masked[LayoutTensor._compute_tile_layout[WN, BK]()[0], TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK](), alignment=alignment]

BMmaTilePair

comptime BMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile]

BTileLoader

comptime BTileLoader = TileLoaderLDS[in_type, b_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].b_half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]

byte_swizzle

comptime byte_swizzle = Optional(Swizzle(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_log_tile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_base, 4)) if enable_swizzle else Optional()

elem_size

comptime elem_size = size_of[in_type]()

elements_per_warp

comptime elements_per_warp = (WARP_SIZE * load_width)

frag_bytes

comptime frag_bytes = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].lds_frag_width * size_of[in_type]())

half_BM

comptime half_BM = WM

half_BN

comptime half_BN = (BN // 2)

HalfTile

comptime HalfTile[rows: Int] = LayoutTensor[in_type, Layout.row_major(rows, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

Parameters

  • rows (Int):

K

comptime K = a_layout.shape[1].value()

lds_frag_width

comptime lds_frag_width = 16 if TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_split_k else TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_frag_width

loading_threads

comptime loading_threads = (loading_warps * WARP_SIZE)

loads_per_row

comptime loads_per_row = (BK // load_width)

mma_frag_width

comptime mma_frag_width = ((MMA_M * MMA_K) // WARP_SIZE)

mma_tile_m

comptime mma_tile_m = (WM // 2)

mma_tile_n

comptime mma_tile_n = (WN // 2)

rows_per_load_iteration

comptime rows_per_load_iteration = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loading_threads // TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)

rows_per_warp

comptime rows_per_warp = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elements_per_warp // BK)

SMemTile

comptime SMemTile[rows: Int, cols: Int] = LayoutTensor[in_type, Layout.row_major(rows, cols), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]

Parameters

  • rows (Int):
  • cols (Int):

swizzle_base

comptime swizzle_base = log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].frag_bytes) if in_type.is_float8() else (log2_floor((TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_subtile_cols // 2)) + log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elem_size))

swizzle_log_tile

comptime swizzle_log_tile = (log2_floor((MMA_K // 32)) + 1)

swizzle_shift

comptime swizzle_shift = 4

swizzle_subtile_cols

comptime swizzle_subtile_cols = (4 * load_width)

total_warps

comptime total_warps = 8

use_fp8_row_major

comptime use_fp8_row_major = in_type.is_float8()

use_split_k

comptime use_split_k = in_type.is_float8() and (MMA_K == 128)

vmcnt_per_load_a

comptime vmcnt_per_load_a = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM // TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)

vmcnt_per_load_ab

comptime vmcnt_per_load_ab = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_a + TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_b)

vmcnt_per_load_b

comptime vmcnt_per_load_b = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)

Methods

__init__

__init__(out self, a: LayoutTensor[in_type, a_layout, a.origin, address_space=a.address_space, element_layout=a.element_layout, layout_int_type=a.layout_int_type, linear_idx_type=a.linear_idx_type, masked=a.masked, alignment=a.alignment], b: LayoutTensor[b.dtype, b_layout, b.origin, address_space=b.address_space, element_layout=b.element_layout, layout_int_type=b.layout_int_type, linear_idx_type=b.linear_idx_type, masked=b.masked, alignment=b.alignment], block_row: Int, block_col: Int, warp_id: Int, warp_id_m: Int, warp_id_n: Int, lane_id: Int)

Initialize LDS tiles and loaders. The layouts are inferred from the a and b tensors.

load_a

load_a[stage: Int, which: Int](self, *, k: Int)

Load A[stage][which] from global to LDS using all 8 warps.

load_b

load_b[stage: Int, which: Int](self, *, k: Int)

Load B[stage][which] from global to LDS using all 8 warps.

Was this page helpful?