Mojo struct

BlockScaledMatmulConfig

struct BlockScaledMatmulConfig[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool = True]

Static configuration of a block-scaled GPU matmul.

Fields

  • cta_group (Int)
  • mma_shape (IndexList[3])
  • cluster_shape (IndexList[3])
  • AB_swapped (Bool)
  • block_swizzle_size (Int)
  • raster_order (RasterOrder)
  • register_based_epilogue (Bool)
  • block_tile_shape (IndexList[3])
  • num_split_k (Int)
  • num_pipeline_stages (Int)
  • num_clc_pipeline_stages (Int)
  • num_accum_pipeline_stages (Int)
  • num_output_stages (Int)
  • output_tile_shape (IndexList[2])
  • a_swizzle (TensorMapSwizzle)
  • b_swizzle (TensorMapSwizzle)
  • c_swizzle (TensorMapSwizzle)
  • k_group_size (Int)
  • scaling_kind (UMMAKind)
  • vec_sf_size (Int)
  • num_sf_k_tiles (Int)
  • is_small_bn (Bool)
  • gemm_kind (GEMMKind)

Implemented traits

AnyType, Copyable, Equatable, Hashable, ImplicitlyCopyable, ImplicitlyDeletable, Movable, RegisterPassable, TrivialRegisterPassable, Writable

comptime members

accum_type

comptime accum_type = get_accum_type[a_type]()

sf_block_atom_size

comptime sf_block_atom_size = SF_ATOM_M[0] * SF_ATOM_M[1] * 4
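sf_block_atom_size multiplies the two components of SF_ATOM_M by 4. This page does not give the value of SF_ATOM_M, so the worked number below is only a sketch under an assumption: that SF_ATOM_M = (32, 4), the usual factoring of the 128-row scale-factor atom used by NVIDIA Blackwell block-scaled MMA (as in CUTLASS), with the trailing 4 presumably counting scale factors along K.

```latex
% Assumption: SF_ATOM_M = (32, 4); not stated on this page.
\mathtt{sf\_block\_atom\_size}
  = \mathtt{SF\_ATOM\_M}[0] \times \mathtt{SF\_ATOM\_M}[1] \times 4
  = 32 \times 4 \times 4
  = 512
```

Under that assumption, one atom covers 32 × 4 = 128 rows with 4 scale-factor entries each. If SF_ATOM_M has different values in your build, substitute them into the same product.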

Methods

__init__

def __init__(
    *,
    scaling_kind: UMMAKind,
    cta_group: Int = 2,
    mma_shape: IndexList[3] = get_mma_shape[a_type, BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b].accum_type](),
    cluster_shape: IndexList[3] = Index[Int, Int, Int](2, 1, 1),
    AB_swapped: Bool = False,
    num_split_k: Int = 1,
    block_swizzle_size: Int = 0,
    raster_order: RasterOrder = RasterOrder.AlongM,
    k_group_size: Int = 1,
    num_pipeline_stages: Optional[Int] = None,
    num_accum_pipeline_stages: Int = 2,
    num_clc_pipeline_stages: Int = 2,
    is_gmm: Bool = False,
    is_small_bn: Bool = False,
    register_based_epilogue: Bool = True,
    gemm_kind: GEMMKind = GEMMKind.GEMM,
) -> Self

All arguments are keyword-only (note the leading *). scaling_kind is the only argument without a default.

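swap_AB_type

def swap_AB_type(self) -> BlockScaledMatmulConfig[b_type, a_type, c_type, sfb_dtype, sfa_dtype, transpose_b]

Returns:

BlockScaledMatmulConfig[b_type, a_type, c_type, sfb_dtype, sfa_dtype, transpose_b]: a configuration with the operand types (a_type, b_type) exchanged and the scale-factor dtypes (sfa_dtype, sfb_dtype) exchanged along with them. c_type and transpose_b are unchanged.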
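The transpose identity below shows why swapping the A and B operands also swaps their scale-factor dtypes, while c_type and transpose_b stay the same. The derivation assumes a standard block-scaled layout that this page does not spell out. A is M×K and B is N×K (transpose_b = True). Each run of v consecutive K elements shares one scale factor, with v assumed to play the role of vec_sf_size. The symbols s^A, s^B, v, and the hatted matrices are introduced here for illustration only. The index i runs over M, j over N, and k over K.

```latex
% Assumed notation: s^A, s^B are per-block scale factors; v is the block length along K.
\begin{aligned}
\hat{A}_{ik} &= s^{A}_{i,\lfloor k/v \rfloor}\, A_{ik}, &
\hat{B}_{jk} &= s^{B}_{j,\lfloor k/v \rfloor}\, B_{jk}, \\
C &= \hat{A}\,\hat{B}^{\mathsf{T}}, &
C^{\mathsf{T}} &= \hat{B}\,\hat{A}^{\mathsf{T}}.
\end{aligned}
```

In the swapped product, B takes the left-operand role and A the right-operand role. Each operand keeps its own scale factors, and the right operand still enters transposed, so transpose_b is unchanged. The output elements are the same values, laid out as C^T instead of C, so c_type is unchanged as well.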

write_to

def write_to[W: Writer](self, mut writer: W)

write_repr_to

def write_repr_to(self, mut writer: T)

get_kernel_name

def get_kernel_name(self) -> String

Returns:

String