Skip to main content

Mojo function

load_AB

load_AB[a_type: DType, b_type: DType, a_tile_rank: Int, a_tile_shape: IndexList[a_tile_rank], a_desc_shape: IndexList[a_tile_rank], b_tile_rank: Int, b_tile_shape: IndexList[b_tile_rank], b_desc_shape: IndexList[b_tile_rank], a_dim0: Int, a_dim1: Int, a_num_tiles: Int, a_swizzle_bytes: Int, b_dim0: Int, b_dim1: Int, b_num_tiles: Int, b_swizzle_bytes: Int, num_pipeline_stages: UInt, //, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](expert_ids: UnsafePointer[Int32, ImmutAnyOrigin], a_tma_op: TMATensorTile[a_type, a_tile_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_tile_rank, b_tile_shape, b_desc_shape], a_smem_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, a_num_tiles, a_swizzle_bytes], b_smem_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, b_num_tiles, b_swizzle_bytes], mma_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], tma_mbar: UnsafePointer[SharedMemBarrier, MutAnyOrigin, address_space=AddressSpace.SHARED], producer_phase: PipelineState[Int[UInt](num_pipeline_stages)], peer_cta_coord: Tuple[Int, Int, Int], work_tile_coord: Tuple[Int, Int], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool, scheduler: TileScheduler[static_MN=scheduler.static_MN, tile_shape=scheduler.tile_shape, cluster=scheduler.cluster, cta_group=scheduler.cta_group, swizzle=scheduler.swizzle, swapAB=scheduler.swapAB])

Was this page helpful?