IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.

Mojo struct

STMatrixLayout

struct STMatrixLayout[BM: Int, BN: Int, *, num_threads: Int, accum_dtype_size: Int]

Layout for writing the final accumulator from registers to shared memory (smem) with the st_matrix instruction.

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

bits

comptime bits = (Int(64) * accum_dtype_size)

bits_per_byte

comptime bits_per_byte = 8

element_layout

comptime element_layout = Layout.row_major(Int(1), Int(2))

elements_per_repeat

comptime elements_per_repeat = Int(4)

frag_simdwidth

comptime frag_simdwidth = Int(2)
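element_layout is a 1 x 2 row-major layout, so each element is a contiguous pair of values. That pair has the same width as frag_simdwidth. The Python sketch below uses a hypothetical helper, row_major_offsets, to model the usual row-major rule (shape (rows, cols) maps to strides (cols, 1)). It is not Mojo's Layout API. Treating the pair as one width-2 SIMD vector is an inference from the two constants.

```python
# Hypothetical model of Layout.row_major: shape (rows, cols) -> strides (cols, 1).
# This is a plain-Python sketch, not Mojo's Layout API.

def row_major_offsets(rows: int, cols: int) -> list[list[int]]:
    """Return the linear offset of every (row, col) coordinate."""
    return [[r * cols + c for c in range(cols)] for r in range(rows)]

# element_layout = Layout.row_major(1, 2): one row holding two elements.
ELEMENT_OFFSETS = row_major_offsets(1, 2)
FRAG_SIMDWIDTH = 2

# The element is a single contiguous pair, the same width as frag_simdwidth.
vector_width = len(ELEMENT_OFFSETS[0])
print(ELEMENT_OFFSETS, vector_width == FRAG_SIMDWIDTH)
```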

frag_size

comptime frag_size = ((BN * Int(2)) // Int(4))

num_m_tiles

comptime num_m_tiles = (STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles_total // STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_warpgroups)

num_m_tiles_total

comptime num_m_tiles_total = ceildiv((Int(2) * BM), Int(128))

num_row_blocks_per_mma

comptime num_row_blocks_per_mma = 2

num_warpgroups

comptime num_warpgroups = ceildiv(num_threads, Int(128))
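num_m_tiles depends on num_m_tiles_total and num_warpgroups. Since 2 * BM / 128 = BM / 64, num_m_tiles_total equals one tile per 64 rows of BM, rounded up. num_warpgroups counts groups of 128 threads, also rounded up. The Python sketch below mirrors these three expressions. The tile_counts helper and the example configurations are for illustration only.

```python
# Plain-Python mirror of the STMatrixLayout tile-count expressions.
# ceildiv is ordinary ceiling division; "//" is floor division, as in the Mojo code.

def ceildiv(a: int, b: int) -> int:
    return -(-a // b)

def tile_counts(bm: int, num_threads: int) -> dict[str, int]:
    num_m_tiles_total = ceildiv(2 * bm, 128)    # ceil(BM / 64)
    num_warpgroups = ceildiv(num_threads, 128)  # 128 threads per warpgroup
    num_m_tiles = num_m_tiles_total // num_warpgroups
    return {
        "num_m_tiles_total": num_m_tiles_total,
        "num_warpgroups": num_warpgroups,
        "num_m_tiles": num_m_tiles,
    }

# Example configurations (illustrative values, not taken from any kernel).
print(tile_counts(bm=128, num_threads=256))  # 2 tiles over 2 warpgroups -> 1 each
print(tile_counts(bm=256, num_threads=128))  # 4 tiles over 1 warpgroup -> 4
```

num_m_tiles uses floor division. A configuration with more warpgroups than M tiles therefore gives num_m_tiles = 0: for example, BM = 128 with 384 threads yields 2 // 3.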

repeat

comptime repeat = (BN // Int(8))
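frag_size = (BN * 2) // 4 simplifies to BN // 2. When BN is a multiple of 8, it also equals repeat * elements_per_repeat. The Python sketch below only checks that arithmetic. Whether the kernel relies on this identity is an assumption.

```python
# Relating frag_size, repeat and elements_per_repeat
# (values and formulas copied from the comptime members above).
ELEMENTS_PER_REPEAT = 4

def frag_size(bn: int) -> int:
    return (bn * 2) // 4

def repeat(bn: int) -> int:
    return bn // 8

for bn in (8, 64, 128, 256):
    # For BN divisible by 8, each repeat contributes elements_per_repeat values.
    assert frag_size(bn) == repeat(bn) * ELEMENTS_PER_REPEAT
    assert frag_size(bn) == bn // 2
    print(bn, repeat(bn), frag_size(bn))
```

The identity fails for BN that is not a multiple of 8. For example, BN = 12 gives frag_size = 6 but repeat * elements_per_repeat = 4.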

row_of_frags_layout

comptime row_of_frags_layout = Layout.row_major(STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_size)
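row_of_frags_layout is row-major with shape (num_m_tiles, frag_size), so coordinate (i, j) maps to offset i * frag_size + j. Each M tile's fragment therefore occupies one contiguous run of frag_size values. The Python sketch below assumes the standard row-major stride rule. The helper row_of_frags_offset and the sizes used are hypothetical.

```python
# Row-major (num_m_tiles, frag_size) layout: (i, j) -> i * frag_size + j.

def row_of_frags_offset(i: int, j: int, num_m_tiles: int, frag_size: int) -> int:
    if not (0 <= i < num_m_tiles and 0 <= j < frag_size):
        raise IndexError("coordinate out of range")
    return i * frag_size + j

# Illustrative sizes: num_m_tiles = 2, frag_size = 64 (that is, BN = 128).
NUM_M_TILES, FRAG_SIZE = 2, 64
offsets = [row_of_frags_offset(i, j, NUM_M_TILES, FRAG_SIZE)
           for i in range(NUM_M_TILES) for j in range(FRAG_SIZE)]
assert offsets == list(range(NUM_M_TILES * FRAG_SIZE))  # dense, no gaps
print(row_of_frags_offset(1, 0, NUM_M_TILES, FRAG_SIZE))  # second tile starts at 64
```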

TensorType

comptime TensorType[dtype: DType] = LayoutTensor[dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].element_layout]

Parameters

dtype (DType): element type of the LayoutTensor that TensorType produces.

thread_cols

comptime thread_cols = 4

vec_local_layout

comptime vec_local_layout = Layout(IntTuple(IntTuple(Int(2), STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].num_m_tiles), IntTuple((BN // Int(8)))), IntTuple(IntTuple(Int(2), STMatrixLayout[BM, BN, num_threads=num_threads, accum_dtype_size=accum_dtype_size].frag_size), IntTuple(Int(4))))
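vec_local_layout has shape ((2, num_m_tiles), BN // 8) and stride ((2, frag_size), 4). Coordinate ((a, b), c) therefore lands at offset 2a + frag_size * b + 4c. The Python sketch below enumerates these offsets with the first mode varying fastest. It assumes two things: strides count scalar elements, and element_layout (row_major(1, 2)) expands each offset into the two adjacent scalars base and base + 1. Under those assumptions, the check confirms that the vectors cover the same num_m_tiles * frag_size region as row_of_frags_layout, each scalar exactly once, for the sizes tried. The helper vec_local_scalar_offsets is hypothetical.

```python
# Enumerate vec_local_layout, shape ((2, num_m_tiles), BN // 8) and
# stride ((2, frag_size), 4), then expand each position into the two scalars
# of element_layout = row_major(1, 2).
# Assumption: strides are in scalar units; a vector covers base and base + 1.

def vec_local_scalar_offsets(bn: int, num_m_tiles: int) -> list[int]:
    frag_size = (bn * 2) // 4
    repeat = bn // 8
    scalars = []
    for c in range(repeat):              # second mode: BN // 8 entries, stride 4
        for b in range(num_m_tiles):     # first mode, inner index 1: stride frag_size
            for a in range(2):           # first mode, inner index 0: stride 2 (fastest)
                base = 2 * a + frag_size * b + 4 * c
                scalars.extend([base, base + 1])  # 1 x 2 element vector
    return scalars

BN, NUM_M_TILES = 128, 2
offs = vec_local_scalar_offsets(BN, NUM_M_TILES)
region = NUM_M_TILES * ((BN * 2) // 4)  # num_m_tiles * frag_size
# Every scalar in the region is covered exactly once.
assert sorted(offs) == list(range(region))
print(len(offs), region)
```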

Methods

__init__

def __init__() -> Self