Skip to main content

Mojo struct

STMatrixLayout

struct STMatrixLayout[BM: Int, BN: Int, *, num_threads: Int, accum_dtype_size: Int]

Layout for using `st_matrix` to write the final accumulator to shared memory (smem).

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

bits

comptime bits = (((8 * STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth) * 4) * accum_dtype_size)

bits_per_byte

comptime bits_per_byte = 8

element_layout

comptime element_layout = Layout.row_major(VariadicList(1, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth))

elements_per_repeat

comptime elements_per_repeat = (STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth * 2)

frag_simdwidth

comptime frag_simdwidth = 2

frag_size

comptime frag_size = ((BN * 2) // 4)

num_m_tiles

comptime num_m_tiles = (STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles_total // STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_warpgroups)

num_m_tiles_total

comptime num_m_tiles_total = ceildiv((2 * BM), 128)

num_row_blocks_per_mma

comptime num_row_blocks_per_mma = 2

num_warpgroups

comptime num_warpgroups = ceildiv(num_threads, 128)

repeat

comptime repeat = (BN // (4 * STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth))

row_of_frags_layout

comptime row_of_frags_layout = Layout.row_major(VariadicList(STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_size))

TensorType

comptime TensorType[dtype: DType] = LayoutTensor[dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].element_layout]

Parameters (of `TensorType`)

- dtype (DType): element data type of the resulting `LayoutTensor`.

thread_cols

comptime thread_cols = 4

vec_local_layout

comptime vec_local_layout = Layout(IntTuple(VariadicList(IntTuple(VariadicList(2, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles)), IntTuple(STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].repeat)), Tuple()), IntTuple(VariadicList(IntTuple(VariadicList(STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_size)), IntTuple((2 * STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_simdwidth))), Tuple()))

Methods

__init__

__init__() -> Self

Was this page helpful?