IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

MatmulConfig

struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]

Static configuration for a GPU matrix-multiplication (matmul) kernel.

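Fields

  • block_tile_shape (IndexList[Int(3)]): `__init__` default: `(128, 128, 32)`.
  • warp_tile_shape (IndexList[Int(3)]): `__init__` default: `(64, 64, 32)`.
  • mma_shape (IndexList[Int(3)]): `__init__` default: `get_mma_shape[a_type, accum_type]()`.
  • num_pipeline_stages (Int): `__init__` default: `4`.
  • num_k_partitions (Int): `__init__` default: `1`.
  • k_group_size (Int): `__init__` default: `1`.
  • num_warp_k_partitions (Int): `__init__` default: `1`.
  • cluster_shape (IndexList[Int(3)]): `__init__` default: `(1, 1, 1)`.
  • num_consumer (Int): `__init__` default: `1`.
  • partitioned_multicast (Bool): `__init__` default: `False`.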

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable, Writable

comptime members

ACCUM_PRECISION

comptime ACCUM_PRECISION = 1

accum_type

comptime accum_type = get_accum_type[a_type]()

The accumulator dtype that `get_accum_type` returns for `a_type`.

OUTPUT_PRECISION

comptime OUTPUT_PRECISION = 2

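split_k_reduction_scheme

comptime split_k_reduction_scheme = get_defined_int[StringSlice("SPLITK_REDUCTION_SCHEME"), Int(2)]()

Read at compile time from the `SPLITK_REDUCTION_SCHEME` define. If the define is not set, it falls back to `2`.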

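split_k_reduction_type

comptime split_k_reduction_type = c_type if (Int(2) == get_defined_int[StringSlice("SPLITK_REDUCTION_SCHEME"), Int(2)]()) else MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type

Resolves to `c_type` when the `SPLITK_REDUCTION_SCHEME` value (see `split_k_reduction_scheme`) equals `2`, which is also its fallback value. Otherwise it resolves to `accum_type`.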

Methods

__init__

def __init__(*, block_tile_shape: IndexList[Int(3)] = Index[Int, Int, Int](Int(128), Int(128), Int(32)), warp_tile_shape: IndexList[Int(3)] = Index[Int, Int, Int](Int(64), Int(64), Int(32)), mma_shape: IndexList[Int(3)] = get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[Int(3)] = Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages: Int = Int(4), num_k_partitions: Int = Int(1), k_group_size: Int = Int(1), num_warp_k_partitions: Int = Int(1), num_consumer: Int = Int(1), partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel()) -> Self

__eq__

def __eq__(self, rhs: MatmulConfig) -> Bool

Returns:

Bool

copy_field

def copy_field(mut self, other: MatmulConfig)

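swapAB

def swapAB(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]

Returns a config whose `a_type` and `b_type` parameters are swapped. The `c_type` and `transpose_b` parameters are unchanged.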

Returns:

MatmulConfig[b_type, a_type, c_type, transpose_b]

num_warps_m

def num_warps_m(self) -> Int

Returns:

Int

num_warps_n

def num_warps_n(self) -> Int

Returns:

Int

num_threads

def num_threads(self) -> Int

Returns:

Int

shared_mem_usage

def shared_mem_usage(self) -> Int

Returns:

Int

grid_dim

def grid_dim(self, m: Int, n: Int) -> IndexList[Int(3)]

Returns:

IndexList[Int(3)]

block_dim

def block_dim(self) -> IndexList[Int(3)]

Returns:

IndexList[Int(3)]

work_space_size

def work_space_size(self, M: Int, N: Int) -> Int

Returns:

Int

pdl_level

def pdl_level(self) -> PDLLevel

Returns:

PDLLevel

write_to

def write_to(self, mut writer: T)

write_repr_to

def write_repr_to(self, mut writer: T)

__hash__

def __hash__[H: Hasher](self, mut hasher: H)

Updates `hasher` with the underlying bytes of this config.

Parameters:

  • H (Hasher): The hasher type.

Args:

  • hasher (H): The hasher instance.