IMPORTANT: To view this page as Markdown, append `.md` to the URL (e.g. /max/get-started.md). For the complete documentation index, see llms.txt.
Skip to main content
For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).

Mojo struct

B200BlockScaledMatmulSmem

struct B200BlockScaledMatmulSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]

Fields

  • a_smem (InlineArray[Scalar[a_type], Int(config.block_tile_shape[Int(0)] * config.block_tile_shape[Int(2)] * config.num_pipeline_stages)]):
  • b_smem (InlineArray[Scalar[b_type], Int(config.block_tile_shape[Int(1)] * config.block_tile_shape[Int(2)] * config.num_pipeline_stages)]):
  • c_smem (InlineArray[Scalar[c_type], Int(config.output_tile_shape[Int(0)] * config.output_tile_shape[Int(1)] * config.num_output_stages)]):
  • sfa_smem (InlineArray[Scalar[sfa_dtype], Int((config.block_tile_shape[Int(0)] // (32 * 4)) * 32 * 4 * config.num_pipeline_stages * config.num_sf_k_tiles * 4)]):
  • sfb_smem (InlineArray[Scalar[sfb_dtype], Int((config.mma_shape[Int(1)] // (32 * 4)) * 32 * 4 * config.num_pipeline_stages * config.num_sf_k_tiles * 4)]):
  • tma_mma_mbars (InlineArray[SharedMemBarrier, num_group_pipeline_stages * 2]):
  • accum_mbars (InlineArray[SharedMemBarrier, (config * Int(2))]):
  • tmem_dealloc_mbar (InlineArray[SharedMemBarrier, Int(1)]):
  • tmem_addr (InlineArray[UInt32, Int(1)]):

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

a_smem_size

comptime a_smem_size = Int(config.block_tile_shape[Int(0)] * config.block_tile_shape[Int(2)]) * config.num_pipeline_stages

AScalesType

comptime AScalesType = Scalar[sfa_dtype]

AType

comptime AType = Scalar[a_type]

b_smem_size

comptime b_smem_size = Int(config.block_tile_shape[Int(1)] * config.block_tile_shape[Int(2)]) * config.num_pipeline_stages

BK

comptime BK = config.block_tile_shape[Int(2)]

BM

comptime BM = config.block_tile_shape[Int(0)]

BN

comptime BN = config.block_tile_shape[Int(1)]

BScalesType

comptime BScalesType = Scalar[sfb_dtype]

BType

comptime BType = Scalar[b_type]

c_smem_size

comptime c_smem_size = Int(config.output_tile_shape[Int(0)] * config.output_tile_shape[Int(1)]) * config.num_output_stages

CType

comptime CType = Scalar[c_type]

MMA_K

comptime MMA_K = config.mma_shape[Int(2)]

MMA_M

comptime MMA_M = config.mma_shape[Int(0)]

MMA_N

comptime MMA_N = config.mma_shape[Int(1)]

num_group_pipeline_stages

comptime num_group_pipeline_stages = (config // config)

OutputM

comptime OutputM = config.output_tile_shape[Int(0)]

OutputN

comptime OutputN = config.output_tile_shape[Int(1)]

sfa_smem_size

comptime sfa_smem_size = Int((config.block_tile_shape[Int(0)] // (32 * 4)) * 32 * 4 * config.num_sf_k_tiles * 4) * config.num_pipeline_stages

sfb_smem_size

comptime sfb_smem_size = Int((config.mma_shape[Int(1)] // (32 * 4)) * 32 * 4 * config.num_sf_k_tiles * 4) * config.num_pipeline_stages