For the complete documentation index, see llms.txt. Markdown versions of all pages are available by appending .md to any URL (e.g. /max/get-started.md).
Mojo struct
B200BlockScaledMatmulSmem
struct B200BlockScaledMatmulSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]
Fields
- a_smem (InlineArray[Scalar[a_type], Int((mul config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)], config.num_pipeline_stages))])
- b_smem (InlineArray[Scalar[b_type], Int((mul config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)], config.num_pipeline_stages))])
- c_smem (InlineArray[Scalar[c_type], Int((mul config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)], config.num_output_stages))])
- sfa_smem (InlineArray[Scalar[sfa_dtype], Int((mul (config.block_tile_shape[Int(0)] // Int((mul (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]())))), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]()), config.num_pipeline_stages, config.num_sf_k_tiles, 4))])
- sfb_smem (InlineArray[Scalar[sfb_dtype], Int((mul (align_up(config.mma_shape[Int(1)], Int((mul (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]())))) // Int((mul (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]())))), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]()), config.num_pipeline_stages, config.num_sf_k_tiles, 4))])
- tma_mma_mbars (InlineArray[SharedMemBarrier, ((config // config) * Int(2))])
- accum_mbars (InlineArray[SharedMemBarrier, (config * Int(2))])
- clc_mbars_full (InlineArray[SharedMemBarrier, config.num_clc_pipeline_stages])
- clc_mbars_empty (InlineArray[SharedMemBarrier, config.num_clc_pipeline_stages])
- clc_throttle_mbars (InlineArray[SharedMemBarrier, (config * Int(2))])
- clc_response (InlineArray[UInt128, config.num_clc_pipeline_stages])
- tmem_dealloc_mbar (InlineArray[SharedMemBarrier, Int(1)])
- tmem_addr (InlineArray[UInt32, Int(1)])
Implemented traits
comptime members
a_smem_size
comptime a_smem_size = (Int((mul config.block_tile_shape[Int(0)], config.block_tile_shape[Int(2)])) * config.num_pipeline_stages)
AScalesType
comptime AScalesType = Scalar[sfa_dtype]
AType
comptime AType = Scalar[a_type]
b_smem_size
comptime b_smem_size = (Int((mul config.block_tile_shape[Int(1)], config.block_tile_shape[Int(2)])) * config.num_pipeline_stages)
BK
comptime BK = config.block_tile_shape[Int(2)]
BM
comptime BM = config.block_tile_shape[Int(0)]
BN
comptime BN = config.block_tile_shape[Int(1)]
BScalesType
comptime BScalesType = Scalar[sfb_dtype]
BType
comptime BType = Scalar[b_type]
c_smem_size
comptime c_smem_size = (Int((mul config.output_tile_shape[Int(0)], config.output_tile_shape[Int(1)])) * config.num_output_stages)
CType
comptime CType = Scalar[c_type]
MMA_K
comptime MMA_K = config.mma_shape[Int(2)]
MMA_M
comptime MMA_M = config.mma_shape[Int(0)]
MMA_N
comptime MMA_N = config.mma_shape[Int(1)]
num_group_pipeline_stages
comptime num_group_pipeline_stages = (config // config)
output_m
comptime output_m = config.output_tile_shape[Int(0)]
output_n
comptime output_n = config.output_tile_shape[Int(1)]
sfa_smem_size
comptime sfa_smem_size = (Int((mul (config.block_tile_shape[Int(0)] // Int((mul (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]())))), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]()), config.num_sf_k_tiles, 4)) * config.num_pipeline_stages)
sfb_smem_size
comptime sfb_smem_size = (Int((mul (align_up(config.mma_shape[Int(1)], Int((mul (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]())))) // Int((mul (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]())))), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(0)]()), (load_from_mem Tuple(Int(32), Int(4)).__getitem_param__[Int(1)]()), config.num_sf_k_tiles, 4)) * config.num_pipeline_stages)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!