Skip to main content

Mojo function

load_AB

load_AB[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_tile_rank: Int, a_tile_shape: IndexList[a_tile_rank], a_desc_shape: IndexList[a_tile_rank], b_tile_rank: Int, b_tile_shape: IndexList[b_tile_rank], b_desc_shape: IndexList[b_tile_rank], sfa_tile_rank: Int, sfa_tile_shape: IndexList[sfa_tile_rank], sfa_desc_shape: IndexList[sfa_tile_rank], sfb_tile_rank: Int, sfb_tile_shape: IndexList[sfb_tile_rank], sfb_desc_shape: IndexList[sfb_tile_rank], num_pipeline_stages: Int, group_scale_offsets_layout: Layout, transpose_b: Bool, /, *, a_smem_layout: Layout, b_smem_layout: Layout, sfa_smem_layout: Layout, sfb_smem_layout: Layout, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], block_tile_shape: IndexList[3], mma_shape: IndexList[3], num_sf_k_tiles: Int, cta_group: Int = 1, k_group_size: UInt = 1](a_tma_op: TMATensorTile[a_type, a_tile_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tile_rank, b_tile_shape, b_desc_shape], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_tile_rank, sfa_tile_shape, sfa_desc_shape], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_tile_rank, sfb_tile_shape, sfb_desc_shape], a_smem_base: UnsafePointer[Scalar[a_type], MutAnyOrigin, address_space=AddressSpace.SHARED], b_smem_base: UnsafePointer[Scalar[b_type], MutAnyOrigin, address_space=AddressSpace.SHARED], sfa_smem_base: UnsafePointer[Scalar[sfa_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], sfb_smem_base: UnsafePointer[Scalar[sfb_dtype], MutAnyOrigin, address_space=AddressSpace.SHARED], load_mma_pipeline: ProducerConsumerPipeline[num_pipeline_stages], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool, scheduler: TileScheduler[static_MN=scheduler.static_MN, tile_shape=scheduler.tile_shape, cluster=scheduler.cluster, cta_group=scheduler.cta_group, swizzle=scheduler.swizzle, swapAB=scheduler.swapAB], expert_id: Int32, group_scale_offsets: LayoutTensor[DType.uint32, group_scale_offsets_layout, MutAnyOrigin])

Was this page helpful?