Mojo struct

STMatrixLayout

struct STMatrixLayout[BM: Int, BN: Int, *, num_threads: Int, accum_dtype_size: Int]

Layout for using `st_matrix` to write the final accumulator from registers to shared memory (smem).

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

bits

comptime bits = (((8 * STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth) * 4) * accum_dtype_size)
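The literals `8` and `4` in `bits` match `bits_per_byte` and `thread_cols`. One plausible reading, which is an assumption and not stated in the source, is that `bits` counts the bits held by `thread_cols` threads that each own a `frag_simdwidth`-element fragment of `accum_dtype_size`-byte values. The following is a minimal Python sketch of the arithmetic only. It is not Mojo and not part of the library, and the names `stmatrix_bits`, `BITS_PER_BYTE`, and so on are hypothetical.

```python
# Python model of the comptime `bits` formula (illustration only, not Mojo).
BITS_PER_BYTE = 8    # matches `bits_per_byte`
FRAG_SIMDWIDTH = 2   # matches `frag_simdwidth`
THREAD_COLS = 4      # assumption: the literal 4 in `bits` is `thread_cols`

def stmatrix_bits(accum_dtype_size: int) -> int:
    """Evaluate `bits` for an accumulator element of `accum_dtype_size` bytes."""
    return BITS_PER_BYTE * FRAG_SIMDWIDTH * THREAD_COLS * accum_dtype_size

for name, size in [("float32", 4), ("bfloat16", 2)]:
    print(name, stmatrix_bits(size))
```

With a 4-byte `float32` accumulator this gives 256 bits. With a 2-byte type it gives 128 bits.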

bits_per_byte

comptime bits_per_byte = 8

element_layout

comptime element_layout = Layout.row_major(1, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth)

elements_per_repeat

comptime elements_per_repeat = (STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth * 2)

frag_simdwidth

comptime frag_simdwidth = 2

frag_size

comptime frag_size = ((BN * 2) // 4)
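`frag_size` simplifies to `BN // 2`. One reading is that it counts the accumulator scalars a single thread holds for a 64 x `BN` tile spread over a 128-thread warpgroup, which is the per-thread count for a Hopper WGMMA accumulator with 64-row M tiles. That reading is an assumption, though it is consistent with `num_m_tiles_total = ceildiv(2 * BM, 128)`, which equals `ceildiv(BM, 64)`. The Python check below shows the two expressions agree. The helper `accum_elems_per_thread` is hypothetical.

```python
def frag_size(BN: int) -> int:
    # Same expression as the comptime member: (BN * 2) // 4
    return (BN * 2) // 4

def accum_elems_per_thread(BN: int, tile_m: int = 64, threads: int = 128) -> int:
    # Assumption: a tile_m x BN accumulator tile split evenly across `threads` threads.
    return (tile_m * BN) // threads

for BN in (8, 64, 128, 256):
    assert frag_size(BN) == BN // 2 == accum_elems_per_thread(BN)
    print(BN, frag_size(BN))
```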

num_m_tiles

comptime num_m_tiles = (STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles_total // STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_warpgroups)

num_m_tiles_total

comptime num_m_tiles_total = ceildiv((2 * BM), 128)

num_row_blocks_per_mma

comptime num_row_blocks_per_mma = 2

num_warpgroups

comptime num_warpgroups = ceildiv(num_threads, 128)
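`num_m_tiles` depends on `num_m_tiles_total` and `num_warpgroups`. Together, these three members split `ceildiv(BM, 64)` row tiles across warpgroups of 128 threads. The Python sketch below evaluates all three for a few example configurations. The function name `m_tiles` and the chosen `BM` and `num_threads` values are illustrative only.

```python
def ceildiv(a: int, b: int) -> int:
    # Ceiling division for non-negative integers.
    return -(-a // b)

def m_tiles(BM: int, num_threads: int) -> tuple[int, int, int]:
    num_m_tiles_total = ceildiv(2 * BM, 128)           # same as ceildiv(BM, 64)
    num_warpgroups = ceildiv(num_threads, 128)         # 128 threads per warpgroup
    num_m_tiles = num_m_tiles_total // num_warpgroups  # floor division, as in the source
    return num_m_tiles_total, num_warpgroups, num_m_tiles

print(m_tiles(128, 128))  # (2, 1, 2)
print(m_tiles(128, 256))  # (2, 2, 1)
print(m_tiles(64, 256))   # (1, 2, 0)
```

Because `num_m_tiles` uses floor division, a configuration where `num_m_tiles_total` is not a multiple of `num_warpgroups` rounds down. The last case, `BM = 64` with two warpgroups, yields 0. Presumably callers choose `BM` and `num_threads` so that the division is exact.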

repeat

comptime repeat = (BN // (4 * STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth))
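`repeat` works out to `BN // 8`. Multiplying it by `elements_per_repeat` (which is 4) recovers `frag_size` whenever `BN` is a multiple of 8. The Python check below reads the literal `4` in `repeat` as `thread_cols` and the literal `2` in `elements_per_repeat` as `num_row_blocks_per_mma`. Both readings are assumptions based on the matching values; the arithmetic itself does not depend on them.

```python
FRAG_SIMDWIDTH = 2          # `frag_simdwidth`
THREAD_COLS = 4             # assumption: the literal 4 in `repeat`
NUM_ROW_BLOCKS_PER_MMA = 2  # assumption: the literal 2 in `elements_per_repeat`

def repeat(BN: int) -> int:
    # Mirrors `BN // (4 * frag_simdwidth)`, i.e. BN // 8.
    return BN // (THREAD_COLS * FRAG_SIMDWIDTH)

def frag_size(BN: int) -> int:
    # Mirrors `(BN * 2) // 4`.
    return (BN * 2) // 4

elements_per_repeat = FRAG_SIMDWIDTH * NUM_ROW_BLOCKS_PER_MMA  # 4

for BN in (8, 64, 256):  # multiples of 8
    # Each repeat contributes elements_per_repeat scalars to the fragment.
    assert repeat(BN) * elements_per_repeat == frag_size(BN)
    print(BN, repeat(BN))
```

The identity breaks if `BN` is not a multiple of 8. For example, `BN = 12` gives `repeat = 1` (4 scalars) but `frag_size = 6`.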

row_of_frags_layout

comptime row_of_frags_layout = Layout.row_major(STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_size)
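`row_of_frags_layout` is a row-major `(num_m_tiles, frag_size)` layout. This suggests, as an interpretation, that row `m` holds the `frag_size` accumulator scalars a thread owns for M tile `m`. A 2-D row-major layout has strides `(cols, 1)`. The Python helper below models that offset function only, not the Mojo `Layout` type. The name `row_major_offset` and the sizes (taken from `BM = 128`, `BN = 256`, one warpgroup) are illustrative.

```python
def row_major_offset(shape: tuple[int, int], coord: tuple[int, int]) -> int:
    # A 2-D row-major layout has strides (cols, 1).
    rows, cols = shape
    r, c = coord
    assert 0 <= r < rows and 0 <= c < cols, "coordinate out of range"
    return r * cols + c

num_m_tiles, frag_size = 2, 128  # e.g. BM=128, BN=256, num_threads=128
shape = (num_m_tiles, frag_size)
print(row_major_offset(shape, (0, 5)))    # 5
print(row_major_offset(shape, (1, 0)))    # 128
print(row_major_offset(shape, (1, 127)))  # 255
```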

TensorType

comptime TensorType[dtype: DType] = LayoutTensor[dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].element_layout]

Parameters

dtype (DType): The data type of the tensor elements.

thread_cols

comptime thread_cols = 4

vec_local_layout

comptime vec_local_layout = Layout(IntTuple(IntTuple(2, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles), IntTuple(STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].repeat), __list_literal__=NoneType(None)), IntTuple(IntTuple(STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_size), IntTuple((2 * STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth)), __list_literal__=NoneType(None)))
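Stripped of the nesting, `vec_local_layout` has shape `((2, num_m_tiles), repeat)` and strides `((frag_simdwidth, frag_size), 2 * frag_simdwidth)`. Combined with `element_layout = row_major(1, frag_simdwidth)`, each coordinate addresses a 2-wide vector. The Python sketch below enumerates the scalar offsets it produces and checks that every one of the `num_m_tiles * frag_size` fragment scalars is hit exactly once. It assumes the strides are measured in scalars rather than vectors. The loop variables `row_block`, `m`, and `r` are hypothetical names for the three modes, and the iteration order is arbitrary because only coverage is checked.

```python
from itertools import product

FRAG_SIMDWIDTH = 2

def vec_offsets(BN: int, num_m_tiles: int) -> list[int]:
    frag_size = (BN * 2) // 4
    repeat = BN // (4 * FRAG_SIMDWIDTH)
    offsets = []
    # Shape ((2, num_m_tiles), repeat), strides ((FRAG_SIMDWIDTH, frag_size), 2 * FRAG_SIMDWIDTH).
    for m, r, row_block in product(range(num_m_tiles), range(repeat), range(2)):
        base = row_block * FRAG_SIMDWIDTH + m * frag_size + r * 2 * FRAG_SIMDWIDTH
        # element_layout row_major(1, FRAG_SIMDWIDTH): each coordinate is a 2-wide vector.
        offsets.extend(base + k for k in range(FRAG_SIMDWIDTH))
    return offsets

offs = vec_offsets(BN=256, num_m_tiles=2)
print(len(offs), sorted(offs) == list(range(2 * 128)))  # 256 True
```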

Methods

__init__

__init__() -> Self