Skip to main content

Mojo struct

MatmulKernels

struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]

Supported matmul kernels.

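The configurations are named as `<arch>_<BN>x<BM>_<stages>`, where BN and BM are the block tile sizes along N and M and `<stages>` is the number of pipeline stages. For example, `ampere_256x128_3` uses a 128×256 (BM×BN) block tile with 3 pipeline stages. BK, the MMA shape, and the warp tile shape are not encoded in the name; they are decided internally.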

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable, RegisterPassable, TrivialRegisterPassable

comptime members

ampere_128x128_4

comptime ampere_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](128, 128, _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](64, 64, _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(4), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())

ampere_256x128_3

comptime ampere_256x128_3 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](128, 256, (2 * _bk_base[a_type]())), warp_tile_shape=Index[Int, Int, Int](64, 64, (2 * _bk_base[a_type]())), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(3), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())

ampere_256x64_4

comptime ampere_256x64_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](64, 256, _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](64, 64, _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(4), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())

hopper_128x128_4

comptime hopper_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](128, 128, _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](64, 64, _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(4), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())

tuning_config

comptime tuning_config = MatmulConfig(block_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_BM"), 128](), get_defined_int[StringSlice("TUNE_BN"), 128](), get_defined_int[StringSlice("TUNE_BK"), 32]()), warp_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_WM"), 64](), get_defined_int[StringSlice("TUNE_WN"), 64](), get_defined_int[StringSlice("TUNE_BK"), 32]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(get_defined_int[StringSlice("TUNE_NUM_STAGES"), 4]()), num_k_partitions=SIMD(get_defined_int[StringSlice("TUNE_NUM_K_PARTITIONS"), 1]()), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(get_defined_int[StringSlice("TUNE_NUM_WARP_K_PARTITIONS"), 1]()), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())

Was this page helpful?