Mojo struct
MatmulKernels
struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]
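Supported matmul kernels.
The configurations are named as `<arch>_<BN>x<BM>_<num_pipeline_stages>`, where BM and BN are the first two entries of `block_tile_shape` (ordered BM, BN, BK). For example, `ampere_256x128_3` uses a 128×256 block tile (BM=128, BN=256) with 3 pipeline stages. `tuning_config` is the exception: its tile sizes and stage counts are read from the `TUNE_*` defines, falling back to the defaults shown.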
Implemented traits
AnyType,
Copyable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable,
RegisterPassable,
TrivialRegisterPassable
comptime members
ampere_128x128_4
comptime ampere_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](128, 128, _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](64, 64, _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(4), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())
ampere_256x128_3
comptime ampere_256x128_3 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](128, 256, (2 * _bk_base[a_type]())), warp_tile_shape=Index[Int, Int, Int](64, 64, (2 * _bk_base[a_type]())), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(3), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())
ampere_256x64_4
comptime ampere_256x64_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](64, 256, _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](64, 64, _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(4), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())
hopper_128x128_4
comptime hopper_128x128_4 = MatmulConfig(block_tile_shape=Index[Int, Int, Int](128, 128, _bk_base[a_type]()), warp_tile_shape=Index[Int, Int, Int](64, 64, _bk_base[a_type]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(4), num_k_partitions=SIMD(1), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(1), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())
tuning_config
comptime tuning_config = MatmulConfig(block_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_BM"), 128](), get_defined_int[StringSlice("TUNE_BN"), 128](), get_defined_int[StringSlice("TUNE_BK"), 32]()), warp_tile_shape=Index[Int, Int, Int](get_defined_int[StringSlice("TUNE_WM"), 64](), get_defined_int[StringSlice("TUNE_WN"), 64](), get_defined_int[StringSlice("TUNE_BK"), 32]()), mma_shape=get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape=Index[Int, Int, Int](1, 1, 1), num_pipeline_stages=SIMD(get_defined_int[StringSlice("TUNE_NUM_STAGES"), 4]()), num_k_partitions=SIMD(get_defined_int[StringSlice("TUNE_NUM_K_PARTITIONS"), 1]()), k_group_size=SIMD(1), num_warp_k_partitions=SIMD(get_defined_int[StringSlice("TUNE_NUM_WARP_K_PARTITIONS"), 1]()), num_consumer=SIMD(1), partitioned_multicast=False, pdl_level=PDLLevel())
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!