Mojo function
load_AB
load_AB[a_type: DType, b_type: DType, a_tile_rank: Int, a_tile_shape: IndexList[a_tile_rank], a_desc_shape: IndexList[a_tile_rank], b_tile_rank: Int, b_tile_shape: IndexList[b_tile_rank], b_desc_shape: IndexList[b_tile_rank], num_pipeline_stages: UInt, /, *, a_smem_layout: Layout, b_smem_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](expert_ids: NDBuffer[DType.int32, ImmutAnyOrigin, DimList.create_unknown[1](), DimList.create_unknown[1]()], a_tma_op: TMATensorTile[a_type, a_tile_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tile_rank, b_tile_shape, b_desc_shape], a_smem_base: UnsafePointer[Scalar[a_type], MutAnyOrigin, address_space=AddressSpace.SHARED], b_smem_base: UnsafePointer[Scalar[b_type], MutAnyOrigin, address_space=AddressSpace.SHARED], mma_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], tma_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], producer_phase: PipelineState[Int.__init__[UInt](num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool, scheduler: TileScheduler[static_MN=scheduler.static_MN, tile_shape=scheduler.tile_shape, cluster=scheduler.cluster, cta_group=scheduler.cta_group, swizzle=scheduler.swizzle, swapAB=scheduler.swapAB])
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!