Mojo struct

MatmulConfig

struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]

Static configuration of a GPU matrix multiplication (matmul) kernel: tile shapes, pipelining, split-K, and cluster settings.

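Fields

  • block_tile_shape (IndexList[3]): (M, N, K) tile computed by one thread block. Defaults to (128, 128, 32).
  • warp_tile_shape (IndexList[3]): (M, N, K) tile computed by one warp. Defaults to (64, 64, 32).
  • mma_shape (IndexList[3]): shape of a single matrix-multiply-accumulate (MMA) instruction. Defaults to get_mma_shape[a_type, accum_type]().
  • num_pipeline_stages (Int): number of stages in the shared-memory load pipeline. Defaults to 4.
  • num_k_partitions (Int): number of split-K partitions. Defaults to 1 (no split-K).
  • k_group_size (Int): defaults to 1.
  • num_warp_k_partitions (Int): number of partitions of K across warps within one block. Defaults to 1.
  • cluster_shape (IndexList[3]): thread block cluster shape. Defaults to (1, 1, 1).
  • num_consumer (Int): number of consumer warp groups. Defaults to 1.
  • partitioned_multicast (Bool): defaults to False.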

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable, Writable

comptime members

ACCUM_PRECISION

comptime ACCUM_PRECISION = 1

Value of split_k_reduction_scheme that performs the split-K reduction in accum_type.

accum_type

comptime accum_type = get_accum_type[a_type]()

Accumulation type, derived from a_type.

OUTPUT_PRECISION

comptime OUTPUT_PRECISION = 2

Value of split_k_reduction_scheme that performs the split-K reduction in c_type, the output type.

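split_k_reduction_scheme

comptime split_k_reduction_scheme = get_defined_int[StringSlice("SPLITK_REDUCTION_SCHEME"), 2]()

Read at compile time from the SPLITK_REDUCTION_SCHEME define (set with -D), defaulting to 2 (OUTPUT_PRECISION).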

split_k_reduction_type

comptime split_k_reduction_type = c_type if (2 == MatmulConfig[a_type, b_type, c_type, transpose_b].split_k_reduction_scheme) else MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type
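The conditional in split_k_reduction_type can be modelled in a few lines. This is a plain-Python sketch, not Mojo code: dtypes are represented as strings, and the bfloat16 output / float32 accumulator pairing is an illustrative assumption.

```python
# Plain-Python model of how MatmulConfig picks the split-K reduction type.
ACCUM_PRECISION = 1   # reduce split-K partials in accum_type
OUTPUT_PRECISION = 2  # reduce split-K partials in c_type


def split_k_reduction_type(c_type: str, accum_type: str, scheme: int = 2) -> str:
    """Mirror of the comptime expression: c_type iff scheme == 2, else accum_type."""
    return c_type if scheme == OUTPUT_PRECISION else accum_type


# Default scheme (2): partial results are kept in the output type.
print(split_k_reduction_type("bfloat16", "float32"))  # bfloat16
# Scheme 1: partial results keep full accumulator precision.
print(split_k_reduction_type("bfloat16", "float32", scheme=ACCUM_PRECISION))  # float32
```

Because the expression tests only for 2, any other scheme value, not just 1, falls back to accum_type.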

Methods

__init__

__init__(*, block_tile_shape: IndexList[3] = Index[Int, Int, Int](128, 128, 32), warp_tile_shape: IndexList[3] = Index[Int, Int, Int](64, 64, 32), mma_shape: IndexList[3] = get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[3] = Index[Int, Int, Int](1, 1, 1), num_pipeline_stages: Int = 4, num_k_partitions: Int = 1, k_group_size: Int = 1, num_warp_k_partitions: Int = 1, num_consumer: Int = 1, partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel()) -> Self
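The initializer takes only keyword arguments, each with a default. The sketch below mirrors those defaults in a Python dataclass so that overriding a subset of them is easy to see. MatmulConfigModel is a hypothetical name; mma_shape is left as None because its Mojo default comes from get_mma_shape, which is not modelled, and pdl_level is omitted.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Shape3 = Tuple[int, int, int]


@dataclass(frozen=True)
class MatmulConfigModel:
    """Hypothetical Python stand-in for MatmulConfig's keyword defaults."""
    block_tile_shape: Shape3 = (128, 128, 32)
    warp_tile_shape: Shape3 = (64, 64, 32)
    # Mojo default: get_mma_shape[a_type, accum_type](); not modelled here.
    mma_shape: Optional[Shape3] = None
    cluster_shape: Shape3 = (1, 1, 1)
    num_pipeline_stages: int = 4
    num_k_partitions: int = 1
    k_group_size: int = 1
    num_warp_k_partitions: int = 1
    num_consumer: int = 1
    partitioned_multicast: bool = False


# The Mojo initializer is keyword-only (`*`); here we simply pass keywords.
cfg = MatmulConfigModel(block_tile_shape=(128, 256, 64), num_pipeline_stages=3)
print(cfg.block_tile_shape, cfg.warp_tile_shape, cfg.num_k_partitions)
```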

__eq__

__eq__(self, rhs: MatmulConfig) -> Bool

Returns:

Bool

copy_field

copy_field(mut self, other: MatmulConfig)

swapAB

swapAB(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]

Returns a MatmulConfig whose a_type and b_type parameters are exchanged.

Returns:

MatmulConfig[b_type, a_type, c_type, transpose_b]
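Exchanging the A and B operands is useful because of the identity (AB)^T = B^T A^T: a kernel can compute the transposed product with the operands in swapped roles. The sketch checks that identity on small lists. This page only confirms that swapAB exchanges a_type and b_type; hypothetical_swap_tile illustrates an assumed, unconfirmed effect on tile shapes (M and N trading places).

```python
def matmul(a, b):
    """Naive matmul on lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(x):
    return [list(row) for row in zip(*x)]


def hypothetical_swap_tile(tile):
    """Assumed effect of swapping A and B on an (M, N, K) tile: M and N trade places."""
    m, n, k = tile
    return (n, m, k)


A = [[1, 2, 3], [4, 5, 6]]        # 2 x 3
B = [[7, 8], [9, 10], [11, 12]]   # 3 x 2
C = matmul(A, B)                  # 2 x 2
# Multiplying with the operands in swapped (and transposed) roles yields C transposed.
C_T = matmul(transpose(B), transpose(A))
print(C == transpose(C_T))                     # True
print(hypothetical_swap_tile((128, 64, 32)))   # (64, 128, 32)
```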

num_warps_m

num_warps_m(self) -> Int

Returns:

Int

num_warps_n

num_warps_n(self) -> Int

Returns:

Int
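A natural reading of num_warps_m and num_warps_n is the number of warp tiles that fit along M and N inside one block tile. The sketch assumes exactly that (block extent divided by warp extent); the divisibility check is this sketch's own addition, not documented behavior.

```python
def _warps_along(dim, block_tile_shape, warp_tile_shape):
    # Sketch-only check: an uneven split would leave part of the block tile uncovered.
    if block_tile_shape[dim] % warp_tile_shape[dim] != 0:
        raise ValueError("warp tile must evenly divide block tile")
    return block_tile_shape[dim] // warp_tile_shape[dim]


def num_warps_m(block_tile_shape, warp_tile_shape):
    return _warps_along(0, block_tile_shape, warp_tile_shape)  # M is index 0


def num_warps_n(block_tile_shape, warp_tile_shape):
    return _warps_along(1, block_tile_shape, warp_tile_shape)  # N is index 1


# Default tiles: (128, 128, 32) blocks, (64, 64, 32) warp tiles.
print(num_warps_m((128, 128, 32), (64, 64, 32)), num_warps_n((128, 128, 32), (64, 64, 32)))  # 2 2
```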

num_threads

num_threads(self) -> Int

Returns:

Int
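A plausible model of num_threads multiplies the warp grid by the warp-level K partitions and the warp width. Both the formula and the 32-lane warp width are assumptions here (NVIDIA warps have 32 lanes; AMD CDNA wavefronts have 64).

```python
WARP_SIZE = 32  # assumption: NVIDIA warp width


def num_threads(block_tile_shape, warp_tile_shape, num_warp_k_partitions=1, warp_size=WARP_SIZE):
    """Hypothetical model: one warp per (M, N) warp tile, replicated per warp-level K partition."""
    warps_m = block_tile_shape[0] // warp_tile_shape[0]
    warps_n = block_tile_shape[1] // warp_tile_shape[1]
    return warps_m * warps_n * num_warp_k_partitions * warp_size


print(num_threads((128, 128, 32), (64, 64, 32)))                            # 128
print(num_threads((128, 128, 32), (64, 64, 32), num_warp_k_partitions=2))   # 256
```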

shared_mem_usage

shared_mem_usage(self) -> Int

Returns:

Int
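To get a feel for the numbers shared_mem_usage reports, the sketch below estimates shared memory as one BM x BK tile of A plus one BN x BK tile of B per pipeline stage. This is an assumed lower bound; the real method may also reserve space for an output tile, barriers, or other buffers. The dtype byte sizes are standard.

```python
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2, "float8_e4m3fn": 1}


def shared_mem_usage(block_tile_shape, num_pipeline_stages, a_type, b_type):
    """Rough estimate (assumption): A and B input tiles, multi-buffered across stages."""
    bm, bn, bk = block_tile_shape
    a_bytes = bm * bk * DTYPE_BYTES[a_type]
    b_bytes = bn * bk * DTYPE_BYTES[b_type]
    return num_pipeline_stages * (a_bytes + b_bytes)


# Defaults (128, 128, 32) with 4 stages and bfloat16 inputs:
print(shared_mem_usage((128, 128, 32), 4, "bfloat16", "bfloat16"))  # 65536 bytes (64 KiB)
```

Under this model, usage grows linearly with num_pipeline_stages, so dropping a stage is the simplest way to fit a per-block shared-memory limit.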

grid_dim

grid_dim(self, m: Int, n: Int) -> IndexList[3]

Returns:

IndexList[3]
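One common launch layout, assumed here rather than confirmed by this page, tiles N along grid x, M along grid y, and places split-K partitions on grid z. The partial tiles at the edges are covered by rounding up.

```python
def ceildiv(a: int, b: int) -> int:
    return -(-a // b)  # integer ceiling division


def grid_dim(m, n, block_tile_shape, num_k_partitions=1):
    """Assumed layout: x tiles N, y tiles M, z indexes split-K partitions."""
    bm, bn, _ = block_tile_shape
    return (ceildiv(n, bn), ceildiv(m, bm), num_k_partitions)


print(grid_dim(1000, 4096, (128, 128, 32)))                      # (32, 8, 1)
print(grid_dim(1000, 4096, (128, 128, 32), num_k_partitions=4))  # (32, 8, 4)
```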

block_dim

block_dim(self) -> IndexList[3]

Returns:

IndexList[3]
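block_dim returns a three-component thread-block shape. The sketch assumes a flat one-dimensional block holding every warp of the block, which pairs naturally with the num_threads model; the 32-lane warp width is again an assumption.

```python
WARP_SIZE = 32  # assumption: NVIDIA warp width


def block_dim(num_warps_m, num_warps_n, num_warp_k_partitions=1, warp_size=WARP_SIZE):
    """Assumed: a flat 1-D thread block containing every warp."""
    return (num_warps_m * num_warps_n * num_warp_k_partitions * warp_size, 1, 1)


print(block_dim(2, 2))  # (128, 1, 1) for the default tiles
```

With a flat block, a thread's warp index is its x index divided by the warp width, and the warp's position in the M x N warp grid can be recovered from that index.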

work_space_size

work_space_size(self, M: Int, N: Int) -> Int

Returns:

Int
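This page does not state the unit of work_space_size or how it counts partitions. The sketch assumes one M x N slice of partial results per split-K partition, counted in elements of split_k_reduction_type, and no workspace without split-K. The real method may differ, for example by omitting a slice for a partition that writes C directly.

```python
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "float16": 2}


def work_space_elems(m, n, num_k_partitions):
    """Hypothetical sizing: one M x N partial-result slice per split-K partition."""
    return m * n * num_k_partitions if num_k_partitions > 1 else 0


def work_space_bytes(m, n, num_k_partitions, reduction_type):
    return work_space_elems(m, n, num_k_partitions) * DTYPE_BYTES[reduction_type]


print(work_space_bytes(4096, 4096, 4, "bfloat16"))  # 134217728 (128 MiB)
print(work_space_bytes(4096, 4096, 4, "float32"))   # 268435456 (256 MiB)
```

Under these assumptions, reducing in a bfloat16 output type (OUTPUT_PRECISION) halves the workspace compared with a float32 reduction, at the cost of rounding partial sums to the output type.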

pdl_level

pdl_level(self) -> PDLLevel

Returns:

PDLLevel

write_to

write_to(self, mut writer: T)

write_repr_to

write_repr_to(self, mut writer: T)

__hash__

__hash__[H: Hasher](self, mut hasher: H)

Updates hasher with the underlying bytes.

Parameters:

  • H (Hasher): The hasher type.

Args:

  • hasher (H): The hasher instance.
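Hashing the underlying bytes means that two configurations with identical field values hash identically, which matters when configurations are used as dictionary or cache keys. The Python sketch below illustrates the idea by hashing a packed byte image of a few fields with SHA-256. Both the field subset and the hash function are stand-ins; Mojo's Hasher is a different interface.

```python
import hashlib
import struct


def config_digest(block_tile_shape, warp_tile_shape, num_pipeline_stages, num_k_partitions):
    """Hash a packed little-endian byte image of an illustrative subset of fields."""
    packed = struct.pack("<3q3q2q", *block_tile_shape, *warp_tile_shape,
                         num_pipeline_stages, num_k_partitions)
    return hashlib.sha256(packed).hexdigest()


a = config_digest((128, 128, 32), (64, 64, 32), 4, 1)
b = config_digest((128, 128, 32), (64, 64, 32), 4, 1)
c = config_digest((128, 128, 32), (64, 64, 32), 3, 1)
print(a == b, a == c)  # True False
```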