Mojo struct
MatmulKernels
@register_passable(trivial)
struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]
Supported matmul kernels.
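The configurations are named as `<arch>_<BN>x<BM>_<stages>`. Here `BN` and `BM` are the block-tile sizes along N and M, and `stages` is the number of pipeline stages. The block-tile K size (BK), the warp tile, and the MMA shape are not part of the name. For example, `ampere_256x128_3` uses a 128 (M) × 256 (N) block tile with 3 stages.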
Implemented traits
AnyType, ExplicitlyCopyable, ImplicitlyCopyable, Movable, UnknownDestructibility
Aliases
__copyinit__is_trivial
alias __copyinit__is_trivial = True
__del__is_trivial
alias __del__is_trivial = True
__moveinit__is_trivial
alias __moveinit__is_trivial = True
ampere_128x128_4
alias ampere_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(4), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())
ampere_256x128_3
alias ampere_256x128_3 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 256, (_bk_base[a_type]() * 2)), Index(64, 64, (_bk_base[a_type]() * 2)), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(3), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())
ampere_256x64_4
alias ampere_256x64_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(64, 256, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(4), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())
hopper_128x128_4
alias hopper_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(4), UInt(1), UInt(1), UInt(1), UInt(1), False, PDLLevel())
tuning_config
alias tuning_config = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(env_get_int["TUNE_BM", 128](), env_get_int["TUNE_BN", 128](), env_get_int["TUNE_BK", 32]()), Index(env_get_int["TUNE_WM", 64](), env_get_int["TUNE_WN", 64](), env_get_int["TUNE_BK", 32]()), get_mma_shape[a_type, get_accum_type[a_type]()](), Index(1, 1, 1), UInt(env_get_int["TUNE_NUM_STAGES", 4]()), UInt(env_get_int["TUNE_NUM_K_PARTITIONS", 1]()), UInt(1), UInt(env_get_int["TUNE_NUM_WARP_K_PARTITIONS", 1]()), UInt(1), False, PDLLevel())
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!