Mojo struct
MatmulConfig
struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True]
Static configuration for an SM90 (NVIDIA Hopper) GPU matmul.
Fields
- block_tile_shape (IndexList[3])
- mma_shape (IndexList[3])
- cluster_shape (IndexList[3])
- num_pipeline_stages (Int)
- num_k_partitions (Int)
- num_consumer (Int)
- partitioned_multicast (Bool)
- k_group_size (Int)
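As a rough illustration of how the shape fields relate to a problem size, the Python sketch below counts the output tiles and K-loop iterations a tiled matmul would need. It assumes `block_tile_shape` holds `(BM, BN, BK)` in that order; that ordering and the helper name `tile_counts` are assumptions for illustration, not part of this API.

```python
def tile_counts(m: int, n: int, k: int,
                block_tile_shape: tuple[int, int, int]) -> tuple[int, int, int]:
    """Hypothetical helper: tile grid and K iterations for a tiled matmul.

    Assumes block_tile_shape = (BM, BN, BK); this ordering is an assumption.
    """
    bm, bn, bk = block_tile_shape
    tiles_m = (m + bm - 1) // bm  # output tiles along M (ceiling division)
    tiles_n = (n + bn - 1) // bn  # output tiles along N
    k_iters = (k + bk - 1) // bk  # K-loop iterations per output tile
    return tiles_m, tiles_n, k_iters


print(tile_counts(4096, 4096, 4096, (128, 256, 64)))  # (32, 16, 64)
```

Ceiling division matters when a dimension is not a multiple of the tile size: the last tile along that dimension is only partially filled.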
Implemented traits
AnyType,
Copyable,
Equatable,
Hashable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable,
Writable
Methods
__init__
__init__(block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_shape: IndexList[3], num_pipeline_stages: Int, num_k_partitions: Int, num_consumer: Int, partitioned_multicast: Bool, pdl_level: PDLLevel, k_group_size: Int) -> Self
Initialize MatmulConfig with explicit values for all fields.
__init__(m: Int, n: Int, k: Int, num_k_partitions: Int = 1, partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel.OFF, k_groups: Optional[Int] = None, consumer_groups: Optional[Int] = None, swapAB: Bool = False) -> Self
Initialize MatmulConfig by choosing values for the remaining fields (tile shapes, cluster shape, pipeline stages, and so on) based on the M, N, and K dimensions.
Args:
- m (Int): The M dimension of the matmul.
- n (Int): The N dimension of the matmul.
- k (Int): The K dimension of the matmul.
- num_k_partitions (Int): Number of partitions the K dimension is split into.
- partitioned_multicast (Bool): Whether to use partitioned multicast.
- pdl_level (PDLLevel): Programmatic Dependent Launch (PDL) level for grid controls.
- k_groups (Optional[Int]): How many pipeline stages (loads and stores) are grouped together.
- consumer_groups (Optional[Int]): The number of consumer groups.
- swapAB (Bool): Whether to swap the A and B operands.
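The descriptions of `num_k_partitions` and `swapAB` are terse. Assuming `num_k_partitions` refers to split-K (each partition multiplies a slice of the K dimension and the partial products are summed) and `swapAB` refers to computing the product with the operands exchanged, the Python sketch below checks the two identities those options rely on: split-K partial sums add up to the full product, and (A·B)ᵀ = Bᵀ·Aᵀ. It is a plain reference model on nested lists, not the GPU kernel, and all function names are hypothetical.

```python
def matmul(a, b):
    """Naive reference matmul on row-major lists: a is M x K, b is K x N."""
    k, n = len(b), len(b[0])
    return [[sum(row[p] * b[p][j] for p in range(k)) for j in range(n)] for row in a]


def transpose(x):
    return [list(col) for col in zip(*x)]


def split_k_matmul(a, b, num_k_partitions):
    """Split-K: multiply K-slices independently, then sum the partial results."""
    m, k, n = len(a), len(b), len(b[0])
    chunk = (k + num_k_partitions - 1) // num_k_partitions
    c = [[0] * n for _ in range(m)]
    for start in range(0, k, chunk):
        stop = min(start + chunk, k)
        # Partial product over K indices [start, stop).
        partial = matmul([row[start:stop] for row in a], b[start:stop])
        for i in range(m):
            for j in range(n):
                c[i][j] += partial[i][j]  # reduction across partitions
    return c


a = [[1, 2, 3], [4, 5, 6]]        # 2 x 3
b = [[7, 8], [9, 10], [11, 12]]   # 3 x 2
print(matmul(a, b))               # [[58, 64], [139, 154]]
print(split_k_matmul(a, b, 2))    # [[58, 64], [139, 154]]
# Swapping operands yields the transposed result.
print(matmul(transpose(b), transpose(a)))  # [[58, 139], [64, 154]]
```

On a GPU the reduction across partitions needs extra storage or synchronization, which is why split-K is typically only worthwhile when M and N are too small to fill the device on their own.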
__eq__
adjust_kgroup_size
static adjust_kgroup_size(mma_m: Int, mma_n: Int, K: Int, BK: Int, num_pipeline_stages: Int) -> Int
Returns:
Int
pdl_level
to_base_config
to_base_config(self) -> MatmulConfig[a_type, b_type, c_type, transpose_b]
Converts this configuration to the base MatmulConfig defined in utils_gpu.
Returns:
MatmulConfig[a_type, b_type, c_type, transpose_b]
write_to
write_to(self, mut writer: T)
write_repr_to
write_repr_to(self, mut writer: T)
__hash__
__hash__[H: Hasher](self, mut hasher: H)
Updates hasher with the underlying bytes.
Parameters:
- H (Hasher): The hasher type.
Args:
- hasher (H): The hasher instance.
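To illustrate what hashing "the underlying bytes" means for a trivially register-passable value, here is a Python sketch that assumes the hash depends only on the raw bytes of the fields: equal field values produce identical bytes and therefore identical hashes. `ConfigMirror`, its byte layout, and SHA-256 as a stand-in for a Mojo `Hasher` are all hypothetical; the real struct's layout and hash function may differ.

```python
import hashlib
import struct
from dataclasses import dataclass


@dataclass(frozen=True)
class ConfigMirror:
    """Hypothetical Python mirror of a few MatmulConfig fields."""
    block_tile_shape: tuple[int, int, int]
    num_pipeline_stages: int
    partitioned_multicast: bool

    def to_bytes(self) -> bytes:
        # Fixed little-endian layout (no padding): four signed 64-bit ints, then one bool byte.
        return struct.pack("<4q?", *self.block_tile_shape,
                           self.num_pipeline_stages, self.partitioned_multicast)


def hash_bytes(cfg: ConfigMirror) -> str:
    # SHA-256 stands in for a Hasher: feed it the raw bytes, read the digest.
    h = hashlib.sha256()
    h.update(cfg.to_bytes())
    return h.hexdigest()


a = ConfigMirror((128, 256, 64), 4, False)
b = ConfigMirror((128, 256, 64), 4, False)
c = ConfigMirror((128, 256, 64), 5, False)
print(hash_bytes(a) == hash_bytes(b))  # True
print(hash_bytes(a) == hash_bytes(c))  # False
```

Because equal values hash identically, a configuration that is both Equatable and Hashable can serve as a lookup key, for example in a cache of compiled kernels.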