Mojo struct
MatmulConfig
@register_passable("trivial")
struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False, mma_shape: Index[3] = get_mma_shape[::DType,::DType,::Int]()]
Static configuration of a GPU matmul.
Aliases
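- `accum_type = get_accum_type[::DType,::DType]()`
- `split_k_reduction_scheme = env_get_int[...]()`: an integer define read with `env_get_int`.
- `OUTPUT_PRECISION = 2`
- `ACCUM_PRECISION = 1`
- `split_k_reduction_type = c_type if (split_k_reduction_scheme == OUTPUT_PRECISION) else accum_type`: resolves to `c_type` when `split_k_reduction_scheme` equals `OUTPUT_PRECISION` (2), and to `accum_type` otherwise.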
Fields
- block_tile_shape (Index[3])
- warp_tile_shape (Index[3])
- num_pipeline_stages (UInt)
- num_k_partitions (UInt)
- k_group_size (UInt)
- num_warp_k_partitions (UInt)
- cluster_shape (Index[3])
- num_consumer (UInt)
- partitioned_multicast (Bool)
- scheduler_hint (Index[3])
Implemented traits
AnyType, Copyable, ExplicitlyCopyable, Movable, Stringable, UnknownDestructibility, Writable
Methods
__init__
__init__(block_tile_shape: Index[3] = Index(128, 128, 32), warp_tile_shape: Index[3] = Index(64, 64, 32), cluster_shape: Index[3] = Index(1, 1, 1), num_pipeline_stages: UInt = UInt(4), num_k_partitions: UInt = UInt(1), k_group_size: UInt = UInt(1), num_warp_k_partitions: UInt = UInt(1), num_consumer: UInt = UInt(1), partitioned_multicast: Bool = False, scheduler_hint: Index[3] = Index(2, 2, 2), pdl_level: PDLLevel = PDLLevel()) -> Self
__eq__
__eq__(self, rhs: MatmulConfig[a_type, b_type, c_type, transpose_b, mma_shape]) -> Bool
num_warps_m
num_warps_m(self) -> UInt
num_warps_n
num_warps_n(self) -> UInt
num_threads
num_threads(self) -> UInt
shared_mem_usage
shared_mem_usage(self) -> Int
grid_dim
grid_dim(self, m: UInt, n: UInt) -> Index[3]
block_dim
block_dim(self) -> Index[3]
work_space_size
work_space_size(self, M: UInt, N: UInt) -> UInt
pdl_level
pdl_level(self) -> PDLLevel
__str__
__str__(self) -> String
write_to
write_to[W: Writer](self, mut writer: W)
__repr__
__repr__(self) -> String
__hash__
__hash__[H: _Hasher](self, mut hasher: H)
Updates hasher with the underlying bytes.
Parameters:
- H (_Hasher): The hasher type.
Args:
- hasher (H): The hasher instance.
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!