Mojo struct
STMatrixLayout
struct STMatrixLayout[BM: Int, BN: Int, *, num_threads: Int, accum_dtype_size: Int]
Layout for writing the final accumulator to shared memory (smem) using the st_matrix instruction.
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
bits
comptime bits = (Int(64) * accum_dtype_size)
bits_per_byte
comptime bits_per_byte = 8
element_layout
comptime element_layout = Layout.row_major(Int(1), Int(2))
elements_per_repeat
comptime elements_per_repeat = Int(4)
frag_simdwidth
comptime frag_simdwidth = Int(2)
frag_size
comptime frag_size = ((BN * Int(2)) // Int(4))
num_m_tiles
comptime num_m_tiles = (Self.num_m_tiles_total // Self.num_warpgroups)
num_m_tiles_total
comptime num_m_tiles_total = ceildiv((Int(2) * BM), Int(128))
num_row_blocks_per_mma
comptime num_row_blocks_per_mma = 2
num_warpgroups
comptime num_warpgroups = ceildiv(num_threads, Int(128))
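The tile-count members num_m_tiles, num_m_tiles_total, and num_warpgroups combine as shown in the plain-Python model below. It is a sketch of the compile-time arithmetic, not Mojo code. The helper names are hypothetical, and ceildiv is reimplemented locally. Because 2 * BM is always even, ceildiv(2 * BM, 128) equals ceildiv(BM, 64), which is the number of 64-row tiles needed to cover BM. num_warpgroups counts groups of 128 threads. Because num_m_tiles uses floor division, it is an exact per-warpgroup share only when num_m_tiles_total is a multiple of num_warpgroups.

```python
def ceildiv(a: int, b: int) -> int:
    """Ceiling division for positive integers (local stand-in for ceildiv)."""
    return -(-a // b)


def tile_counts(BM: int, num_threads: int) -> dict:
    """Plain-Python model of the STMatrixLayout tile-count comptime members."""
    num_warpgroups = ceildiv(num_threads, 128)          # groups of 128 threads
    num_m_tiles_total = ceildiv(2 * BM, 128)            # same value as ceildiv(BM, 64)
    num_m_tiles = num_m_tiles_total // num_warpgroups   # floor division, as on this page
    return {
        "num_warpgroups": num_warpgroups,
        "num_m_tiles_total": num_m_tiles_total,
        "num_m_tiles": num_m_tiles,
    }


print(tile_counts(BM=128, num_threads=256))
# {'num_warpgroups': 2, 'num_m_tiles_total': 2, 'num_m_tiles': 1}
```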
repeat
comptime repeat = (BN // Int(8))
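With the constants on this page, repeat * elements_per_repeat equals frag_size whenever BN is a multiple of 8, because both reduce to BN // 2. The following is a quick plain-Python check of that arithmetic, using a hypothetical helper:

```python
def fragment_counts(BN: int) -> tuple[int, int, int]:
    """Model of repeat, elements_per_repeat and frag_size for a given BN."""
    repeat = BN // 8                 # one repeat per 8 columns of BN
    elements_per_repeat = 4          # constant on this page
    frag_size = (BN * 2) // 4        # reduces to BN // 2
    return repeat, elements_per_repeat, frag_size


for BN in (64, 128, 256):
    repeat, per_repeat, frag_size = fragment_counts(BN)
    # For BN divisible by 8, the repeats exactly account for one fragment.
    assert repeat * per_repeat == frag_size
    print(BN, repeat, frag_size)
```

If BN is not a multiple of 8, the two quantities diverge. For example, BN = 12 gives repeat * elements_per_repeat = 4 but frag_size = 6.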
row_of_frags_layout
comptime row_of_frags_layout = Layout.row_major(Self.num_m_tiles, Self.frag_size)
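Layout.row_major(M, N) places coordinate (m, i) at offset m * N + i. As a result, row_of_frags_layout views num_m_tiles * frag_size values as num_m_tiles rows, each holding frag_size contiguous values. Below is a plain-Python model of that indexing. The helpers are hypothetical and are not the Mojo Layout API.

```python
def row_major_offset(m: int, i: int, num_cols: int) -> int:
    """Offset of coordinate (m, i) in a row-major layout with num_cols columns."""
    return m * num_cols + i


def row_of_frags_offsets(num_m_tiles: int, frag_size: int) -> list[list[int]]:
    """Offsets for every (m_tile, frag_index) of row_major(num_m_tiles, frag_size)."""
    return [
        [row_major_offset(m, i, frag_size) for i in range(frag_size)]
        for m in range(num_m_tiles)
    ]


# Example: 2 m-tiles with a 4-value fragment each (frag_size = 4 when BN = 8).
print(row_of_frags_offsets(num_m_tiles=2, frag_size=4))
# [[0, 1, 2, 3], [4, 5, 6, 7]]
```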
TensorType
comptime TensorType[dtype: DType] = LayoutTensor[dtype, Self.vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=Self.element_layout]
Parameters
- dtype (DType): The element data type of the LayoutTensor.
thread_cols
comptime thread_cols = 4
vec_local_layout
comptime vec_local_layout = Layout(IntTuple(IntTuple(Int(2), Self.num_m_tiles), IntTuple((BN // Int(8)))), IntTuple(IntTuple(Int(2), Self.frag_size), IntTuple(Int(4))))
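vec_local_layout has shape ((2, num_m_tiles), (BN // 8)) and strides ((2, frag_size), (4)). Combined with the 1 x 2 element_layout, it maps each coordinate to a pair of adjacent scalars. The plain-Python model below uses hypothetical helper names. It assumes the usual shape/stride rule (offset = sum of coordinate * stride), enumerates those scalar offsets in (m, r, c) loop order, and checks that they cover 0 .. num_m_tiles * frag_size - 1 exactly once. That coverage matches the extent of row_of_frags_layout.

```python
from itertools import product


def vec_local_offsets(BN: int, num_m_tiles: int) -> list[int]:
    """Scalar offsets touched by vec_local_layout with a 1 x 2 element_layout.

    Models shape ((2, num_m_tiles), (BN // 8)) with strides ((2, frag_size), (4)).
    """
    frag_size = (BN * 2) // 4
    offsets = []
    for m, r, c in product(range(num_m_tiles), range(2), range(BN // 8)):
        base = r * 2 + m * frag_size + c * 4   # sum of coordinate * stride
        offsets.extend([base, base + 1])       # element_layout = row_major(1, 2)
    return offsets


offs = vec_local_offsets(BN=16, num_m_tiles=2)
print(sorted(offs) == list(range(2 * 8)))  # True: each offset 0..15 appears once
```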
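Methods
__init__
def __init__() -> Self
Takes no arguments. All of the layout information above is exposed as comptime members derived from the struct parameters.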