Mojo struct
TileBuffers
struct TileBuffers[in_type: DType, a_layout: Layout, b_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, MMA_M: Int, MMA_K: Int, num_threads: Int, alignment: Int, enable_swizzle: Bool, load_width: Int, loading_warps: Int = 8]
Double-buffered LDS tiles and TileLoaders for ping-pong matmul.
a_layout and b_layout are infer-only parameters (note the //), automatically
extracted from the input tensors passed to __init__. K is derived as a
comptime value from a_layout.shape[1].
Fields
- a_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair]):
- b_mma_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair]):
- a_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair]):
- b_load_tiles (Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair]):
- loader_a (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].ATileLoader):
- loader_b (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BTileLoader):
- warp_id_m (Int):
Implemented traits
AnyType,
ImplicitlyDestructible
comptime members
a_half_tile_layout
comptime a_half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK)
AHalfTile
comptime AHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
AHalfTilePair
comptime AHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile]
AMmaTile
comptime AMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK](), alignment=alignment]
AMmaTilePair
comptime AMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile]
ATileLoader
comptime ATileLoader = TileLoaderLDS[in_type, a_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].a_half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]
b_half_tile_layout
comptime b_half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK)
BHalfTile
comptime BHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
BHalfTilePair
comptime BHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile]
BMmaTile
comptime BMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() or _tile_is_masked[LayoutTensor._compute_tile_layout[WN, BK]()[0], TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK](), alignment=alignment]
BMmaTilePair
comptime BMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile]
BTileLoader
comptime BTileLoader = TileLoaderLDS[in_type, b_layout, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].b_half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_fp8_row_major]
byte_swizzle
comptime byte_swizzle = Optional(Swizzle(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_log_tile, TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_base, 4)) if enable_swizzle else Optional()
elem_size
comptime elem_size = size_of[in_type]()
elements_per_warp
comptime elements_per_warp = (WARP_SIZE * load_width)
frag_bytes
comptime frag_bytes = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].lds_frag_width * size_of[in_type]())
half_BM
comptime half_BM = WM
half_BN
comptime half_BN = (BN // 2)
HalfTile
comptime HalfTile[rows: Int] = LayoutTensor[in_type, Layout.row_major(rows, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
Parameters
- rows (Int):
K
comptime K = a_layout.shape[1].value()
lds_frag_width
comptime lds_frag_width = 16 if TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].use_split_k else TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_frag_width
loading_threads
comptime loading_threads = (loading_warps * WARP_SIZE)
loads_per_row
comptime loads_per_row = (BK // load_width)
mma_frag_width
comptime mma_frag_width = ((MMA_M * MMA_K) // WARP_SIZE)
mma_tile_m
comptime mma_tile_m = (WM // 2)
mma_tile_n
comptime mma_tile_n = (WN // 2)
rows_per_load_iteration
comptime rows_per_load_iteration = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loading_threads // TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)
rows_per_warp
comptime rows_per_warp = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elements_per_warp // BK)
SMemTile
comptime SMemTile[rows: Int, cols: Int] = LayoutTensor[in_type, Layout.row_major(rows, cols), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]
Parameters
- rows (Int):
- cols (Int):
swizzle_base
comptime swizzle_base = log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].frag_bytes) if in_type.is_float8() else (log2_floor((TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_subtile_cols // 2)) + log2_floor(TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].elem_size))
swizzle_log_tile
comptime swizzle_log_tile = (log2_floor((MMA_K // 32)) + 1)
swizzle_shift
comptime swizzle_shift = 4
swizzle_subtile_cols
comptime swizzle_subtile_cols = (4 * load_width)
total_warps
comptime total_warps = 8
use_fp8_row_major
comptime use_fp8_row_major = in_type.is_float8()
use_split_k
comptime use_split_k = in_type.is_float8() and (MMA_K == 128)
vmcnt_per_load_a
comptime vmcnt_per_load_a = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM // TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)
vmcnt_per_load_ab
comptime vmcnt_per_load_ab = (TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_a + TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_b)
vmcnt_per_load_b
comptime vmcnt_per_load_b = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, MMA_M, MMA_K, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)
Methods
__init__
__init__(out self, a: LayoutTensor[in_type, a_layout, a.origin, address_space=a.address_space, element_layout=a.element_layout, layout_int_type=a.layout_int_type, linear_idx_type=a.linear_idx_type, masked=a.masked, alignment=a.alignment], b: LayoutTensor[b.dtype, b_layout, b.origin, address_space=b.address_space, element_layout=b.element_layout, layout_int_type=b.layout_int_type, linear_idx_type=b.linear_idx_type, masked=b.masked, alignment=b.alignment], block_row: Int, block_col: Int, warp_id: Int, warp_id_m: Int, warp_id_n: Int, lane_id: Int)
Initialize LDS tiles and loaders. Layouts are inferred from the a and b tensors.
load_a
load_a[stage: Int, which: Int](self, *, k: Int)
Load A[stage][which] from global to LDS using all 8 warps.
load_b
load_b[stage: Int, which: Int](self, *, k: Int)
Load B[stage][which] from global to LDS using all 8 warps.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!