For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
MatmulConfig
struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]
Static configuration of GPU matmul.
Fields
- block_tile_shape (IndexList[Int(3)])
- warp_tile_shape (IndexList[Int(3)])
- mma_shape (IndexList[Int(3)])
- num_pipeline_stages (Int)
- num_k_partitions (Int)
- k_group_size (Int)
- num_warp_k_partitions (Int)
- cluster_shape (IndexList[Int(3)])
- num_consumer (Int)
- partitioned_multicast (Bool)
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable,
Writable
comptime members
ACCUM_PRECISION
comptime ACCUM_PRECISION = 1
accum_type
comptime accum_type = get_accum_type[a_type]()
OUTPUT_PRECISION
comptime OUTPUT_PRECISION = 2
split_k_reduction_scheme
comptime split_k_reduction_scheme = get_defined_int[StringSlice("SPLITK_REDUCTION_SCHEME"), Int(2)]()
split_k_reduction_type
comptime split_k_reduction_type = c_type if (Int(2) == get_defined_int[StringSlice("SPLITK_REDUCTION_SCHEME"), Int(2)]()) else MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type
Methods
__init__
def __init__(*, block_tile_shape: IndexList[Int(3)] = Index[Int, Int, Int](Int(128), Int(128), Int(32)), warp_tile_shape: IndexList[Int(3)] = Index[Int, Int, Int](Int(64), Int(64), Int(32)), mma_shape: IndexList[Int(3)] = get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[Int(3)] = Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages: Int = Int(4), num_k_partitions: Int = Int(1), k_group_size: Int = Int(1), num_warp_k_partitions: Int = Int(1), num_consumer: Int = Int(1), partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel()) -> Self
__eq__
copy_field
def copy_field(mut self, other: MatmulConfig)
swapAB
def swapAB(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]
Returns:
MatmulConfig[b_type, a_type, c_type, transpose_b]
num_warps_m
num_warps_n
num_threads
shared_mem_usage
grid_dim
block_dim
work_space_size
pdl_level
write_to
def write_to(self, mut writer: T)
write_repr_to
def write_repr_to(self, mut writer: T)
__hash__
def __hash__[H: Hasher](self, mut hasher: H)
Updates hasher with the underlying bytes.
Parameters:
- H (Hasher): The hasher type.
Args:
- hasher (H): The hasher instance.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!