For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
MatmulKernels
struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]
Supported matmul kernels.
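The configurations are named as `<arch>_<BN>x<BM>_<stages>`, where `BN` and `BM` are the block tile sizes along the N and M dimensions (the second and first entries of `block_tile_shape`), and `stages` is `num_pipeline_stages`. For example, `ampere_256x128_3` uses a 128×256 (M×N) block tile with 3 pipeline stages.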
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDeletable,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
ampere_128x128_4
comptime ampere_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(128), Int(128), _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(4), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())
ampere_256x128_3
comptime ampere_256x128_3 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(128), Int(256), Int(_bk_base[a_type]() * 2)), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), Int(_bk_base[a_type]() * 2)), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(3), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())
ampere_256x64_4
comptime ampere_256x64_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(64), Int(256), _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(4), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())
hopper_128x128_4
comptime hopper_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](Int(128), Int(128), _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](Int(64), Int(64), _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=Int(4), num_k_partitions=Int(1), k_group_size=Int(1), num_warp_k_partitions=Int(1), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())
tuning_config
comptime tuning_config = MatmulConfig(block_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_BM"), Int(128)](), get_defined_int[StringSlice("TUNE_BN"), Int(128)](), get_defined_int[StringSlice("TUNE_BK"), Int(32)]()), warp_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_WM"), Int(64)](), get_defined_int[StringSlice("TUNE_WN"), Int(64)](), get_defined_int[StringSlice("TUNE_BK"), Int(32)]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](Int(1), Int(1), Int(1)), num_pipeline_stages=get_defined_int[StringSlice("TUNE_NUM_STAGES"), Int(4)](), num_k_partitions=get_defined_int[StringSlice("TUNE_NUM_K_PARTITIONS"), Int(1)](), k_group_size=Int(1), num_warp_k_partitions=get_defined_int[StringSlice("TUNE_NUM_WARP_K_PARTITIONS"), Int(1)](), num_consumer=Int(1), partitioned_multicast=False, pdl_level=PDLLevel())
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!