
Mojo struct

MatmulKernels

struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]

Supported matmul kernels.

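The configurations are named as: <arch>_<BNxBM>_<stages>. For example, ampere_256x128_3 targets Ampere with BN=256, BM=128, and 3 pipeline stages. BK, mma shape, and warp tile shape are decided internally.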
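Judging from the members listed on this page, each preset name appears to encode the target architecture, the block tile's N and M extents, and the number of pipeline stages. The following Python sketch decodes a name under that assumption. `KernelName` and `parse_kernel_name` are hypothetical helpers written for illustration and are not part of the Mojo library.

```python
import re
from typing import NamedTuple


class KernelName(NamedTuple):
    """Fields assumed to be encoded in a preset name (illustrative only)."""
    arch: str
    block_n: int
    block_m: int
    num_pipeline_stages: int


# Assumed pattern: <arch>_<BN>x<BM>_<stages>, e.g. "ampere_256x128_3".
_NAME_PATTERN = re.compile(
    r"(?P<arch>[a-z]+)_(?P<bn>\d+)x(?P<bm>\d+)_(?P<stages>\d+)"
)


def parse_kernel_name(name: str) -> KernelName:
    """Split a preset name into its assumed components."""
    match = _NAME_PATTERN.fullmatch(name)
    if match is None:
        raise ValueError(f"not a preset-style name: {name!r}")
    return KernelName(
        arch=match["arch"],
        block_n=int(match["bn"]),
        block_m=int(match["bm"]),
        num_pipeline_stages=int(match["stages"]),
    )


if __name__ == "__main__":
    print(parse_kernel_name("ampere_256x128_3"))
```

The decoded values line up with the listed definition of ampere_256x128_3, whose block_tile_shape begins with 128 (M) and 256 (N) and whose num_pipeline_stages is 3. Names that do not follow the pattern, such as tuning_config, are rejected.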

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

ampere_128x128_4

comptime ampere_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(128), Int(128), _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(4), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())

ampere_256x128_3

comptime ampere_256x128_3 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(128), Int(256), Int((mul _bk_base[a_type](), 2))), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), Int((mul _bk_base[a_type](), 2))), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(3), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())

ampere_256x64_4

comptime ampere_256x64_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(64), Int(256), _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(4), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())

hopper_128x128_4

comptime hopper_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(128), Int(128), _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(4), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())

tuning_config

comptime tuning_config = MatmulConfig(block_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_BM"), Int(128)](), get_defined_int[StringSlice("TUNE_BN"), Int(128)](), get_defined_int[StringSlice("TUNE_BK"), Int(32)]()), warp_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_WM"), Int(64)](), get_defined_int[StringSlice("TUNE_WN"), Int(64)](), get_defined_int[StringSlice("TUNE_BK"), Int(32)]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=get_defined_int[StringSlice("TUNE_NUM_STAGES"), Int(4)](), num_k_partitions=get_defined_int[StringSlice("TUNE_NUM_K_PARTITIONS"), Int(1)](), k_group_size=Int(1), num_warp_k_partitions=get_defined_int[StringSlice("TUNE_NUM_WARP_K_PARTITIONS"), Int(1)](), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())
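The definition above suggests that each `get_defined_int[name, default]()` call yields the value of a compile-time define when one is set and the default otherwise. Note that TUNE_BK feeds the K extent of both the block tile and the warp tile. The Python sketch below models that fallback behaviour only. `TUNING_DEFAULTS` and `resolve_tuning_config` are hypothetical names, and a plain dictionary stands in for the compile-time defines.

```python
from typing import Dict, Optional

# Defaults copied from the tuning_config definition on this page.
TUNING_DEFAULTS: Dict[str, int] = {
    "TUNE_BM": 128,
    "TUNE_BN": 128,
    "TUNE_BK": 32,
    "TUNE_WM": 64,
    "TUNE_WN": 64,
    "TUNE_NUM_STAGES": 4,
    "TUNE_NUM_K_PARTITIONS": 1,
    "TUNE_NUM_WARP_K_PARTITIONS": 1,
}


def resolve_tuning_config(defines: Optional[Dict[str, int]] = None) -> dict:
    """Overlay user-supplied defines on the defaults (illustrative model)."""
    values = {**TUNING_DEFAULTS, **(defines or {})}
    return {
        # Shapes are ordered (M, N, K), matching the TUNE_BM/BN/BK order.
        "block_tile_shape": (values["TUNE_BM"], values["TUNE_BN"], values["TUNE_BK"]),
        # The warp tile reuses TUNE_BK for its K extent.
        "warp_tile_shape": (values["TUNE_WM"], values["TUNE_WN"], values["TUNE_BK"]),
        "num_pipeline_stages": values["TUNE_NUM_STAGES"],
        "num_k_partitions": values["TUNE_NUM_K_PARTITIONS"],
        "num_warp_k_partitions": values["TUNE_NUM_WARP_K_PARTITIONS"],
    }


if __name__ == "__main__":
    print(resolve_tuning_config({"TUNE_BN": 256, "TUNE_NUM_STAGES": 3}))
```

With no overrides, the sketch reproduces the defaults shown in the definition: a 128x128x32 block tile, a 64x64x32 warp tile, and 4 pipeline stages.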