Mojo struct
MatmulConfig
struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]
Static configuration of a GPU matmul.
Fields
- block_tile_shape (IndexList[3])
- warp_tile_shape (IndexList[3])
- mma_shape (IndexList[3])
- num_pipeline_stages (Int)
- num_k_partitions (Int)
- k_group_size (Int)
- num_warp_k_partitions (Int)
- cluster_shape (IndexList[3])
- num_consumer (Int)
- partitioned_multicast (Bool)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable,
Writable
comptime members
ACCUM_PRECISION
comptime ACCUM_PRECISION = 1
accum_type
comptime accum_type = get_accum_type[a_type]()
OUTPUT_PRECISION
comptime OUTPUT_PRECISION = 2
split_k_reduction_scheme
comptime split_k_reduction_scheme = get_defined_int[StringSlice("SPLITK_REDUCTION_SCHEME"), 2]()
Read from the compile-time define `SPLITK_REDUCTION_SCHEME`; falls back to 2 when that define is not set.
split_k_reduction_type
comptime split_k_reduction_type = c_type if (2 == MatmulConfig[a_type, b_type, c_type, transpose_b].split_k_reduction_scheme) else MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type
Resolves to `c_type` when `split_k_reduction_scheme` is 2, and to `accum_type` otherwise.
Methods
__init__
__init__(*, block_tile_shape: IndexList[3] = Index[Int, Int, Int](128, 128, 32), warp_tile_shape: IndexList[3] = Index[Int, Int, Int](64, 64, 32), mma_shape: IndexList[3] = get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[3] = Index[Int, Int, Int](1, 1, 1), num_pipeline_stages: Int = 4, num_k_partitions: Int = 1, k_group_size: Int = 1, num_warp_k_partitions: Int = 1, num_consumer: Int = 1, partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel()) -> Self
__eq__
copy_field
copy_field(mut self, other: MatmulConfig)
swapAB
swapAB(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]
Returns: a `MatmulConfig[b_type, a_type, c_type, transpose_b]`, that is, a configuration whose `a_type` and `b_type` parameters are swapped.
num_warps_m
num_warps_n
num_threads
shared_mem_usage
grid_dim
block_dim
work_space_size
pdl_level
write_to
write_to(self, mut writer: T)
write_repr_to
write_repr_to(self, mut writer: T)
__hash__
__hash__[H: Hasher](self, mut hasher: H)
Updates the hasher with the underlying bytes.
Parameters:
- H (Hasher): The hasher type.
Args:
- hasher (H): The hasher instance.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!