Skip to main content
Log in

Mojo struct

MatmulConfig

@register_passable(trivial) struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False, mma_shape: Index[3] = get_mma_shape[::DType,::DType,::Int]()]

Static configuration for a GPU matrix multiplication (matmul) kernel.

Aliases

  • accum_type = get_accum_type[::DType,::DType]()
  • split_k_reduction_scheme = env_get_int[...]() — a compile-time integer obtained via env_get_int (the name and default arguments were lost when this page was rendered)
  • OUTPUT_PRECISION = 2
  • ACCUM_PRECISION = 1
  • split_k_reduction_type = c_type if split_k_reduction_scheme == OUTPUT_PRECISION (i.e. 2), otherwise accum_type

Fields

  • block_tile_shape (Index[3]):
  • warp_tile_shape (Index[3]):
  • num_pipeline_stages (UInt):
  • num_k_partitions (UInt):
  • k_group_size (UInt):
  • num_warp_k_partitions (UInt):
  • cluster_shape (Index[3]):
  • num_consumer (UInt):
  • partitioned_multicast (Bool):
  • scheduler_hint (Index[3]):

Implemented traits

AnyType, Copyable, ExplicitlyCopyable, Movable, Stringable, UnknownDestructibility, Writable

Methods

__init__

__init__(block_tile_shape: Index[3] = Index(128, 128, 32), warp_tile_shape: Index[3] = Index(64, 64, 32), cluster_shape: Index[3] = Index(1, 1, 1), num_pipeline_stages: UInt = UInt(4), num_k_partitions: UInt = UInt(1), k_group_size: UInt = UInt(1), num_warp_k_partitions: UInt = UInt(1), num_consumer: UInt = UInt(1), partitioned_multicast: Bool = False, scheduler_hint: Index[3] = Index(2, 2, 2), pdl_level: PDLLevel = PDLLevel()) -> Self

__eq__

__eq__(self, rhs: MatmulConfig[a_type, b_type, c_type, transpose_b, mma_shape]) -> Bool

num_warps_m

num_warps_m(self) -> UInt

num_warps_n

num_warps_n(self) -> UInt

num_threads

num_threads(self) -> UInt

shared_mem_usage

shared_mem_usage(self) -> Int

grid_dim

grid_dim(self, m: UInt, n: UInt) -> Index[3]

block_dim

block_dim(self) -> Index[3]

work_space_size

work_space_size(self, M: UInt, N: UInt) -> UInt

pdl_level

pdl_level(self) -> PDLLevel

__str__

__str__(self) -> String

write_to

write_to[W: Writer](self, mut writer: W)

__repr__

__repr__(self) -> String

__hash__

__hash__[H: _Hasher](self, mut hasher: H)

Updates the given hasher with the underlying bytes of this configuration.

Parameters:

  • H (_Hasher): The hasher type.

Args:

  • hasher (H): The hasher instance.

Was this page helpful?