Skip to main content

Mojo struct

MatmulKernels

@register_passable(trivial) struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]

Supported matmul kernels.

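The configurations are named as `<arch>_<BN>x<BM>_<num_stages>`, where BM × BN is the block-tile shape. For example, `ampere_256x128_3` uses a 128 × 256 (BM × BN) block tile with 3 stages. BK, the MMA shape, and the warp-tile shape are decided internally.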

Implemented traits

AnyType, ExplicitlyCopyable, ImplicitlyCopyable, Movable, UnknownDestructibility

Aliases

__copyinit__is_trivial

alias __copyinit__is_trivial = True

__del__is_trivial

alias __del__is_trivial = True

__moveinit__is_trivial

alias __moveinit__is_trivial = True

ampere_128x128_4

alias ampere_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(4), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())

ampere_256x128_3

alias ampere_256x128_3 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 256, (_bk_base[a_type]() * 2)), Index(64, 64, (_bk_base[a_type]() * 2)), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(3), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())

ampere_256x64_4

alias ampere_256x64_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(64, 256, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(4), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())

hopper_128x128_4

alias hopper_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(4), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())

tuning_config

alias tuning_config = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(env_get_int["TUNE_BM", 128](), env_get_int["TUNE_BN", 128](), env_get_int["TUNE_BK", 32]()), Index(env_get_int["TUNE_WM", 64](), env_get_int["TUNE_WN", 64](), env_get_int["TUNE_BK", 32]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(env_get_int["TUNE_NUM_STAGES", 4]()), UInt(env_get_int["TUNE_NUM_K_PARTITIONS", 1]()), UInt(1), UInt(env_get_int["TUNE_NUM_WARP_K_PARTITIONS", 1]()), UInt(1), False, PDLLevel())

Was this page helpful?