Skip to main content

Mojo function

load_AB

load_AB[
    a_type: DType,
    b_type: DType,
    a_scales_type: DType,
    a_tile_rank: Int,
    a_tile_shape: IndexList[a_tile_rank],
    a_desc_shape: IndexList[a_tile_rank],
    b_tile_rank: Int,
    b_tile_shape: IndexList[b_tile_rank],
    b_desc_shape: IndexList[b_tile_rank],
    a_scales_tile_rank: Int,
    a_scales_tile_shape: IndexList[a_scales_tile_rank],
    a_scales_desc_shape: IndexList[a_scales_tile_rank],
    num_pipeline_stages: UInt,
    expert_ids_layout: Layout,
    /,
    *,
    a_smem_layout: Layout,
    b_smem_layout: Layout,
    a_scales_smem_layout: Layout,
    block_tile_shape: IndexList[3],
    mma_shape: IndexList[3],
    cta_group: Int = 1
](
    a_tma_op: TMATensorTile[a_type, a_tile_rank, a_tile_shape, a_desc_shape],
    b_tma_op: TMATensorTile[b_type, b_tile_rank, b_tile_shape, b_desc_shape],
    a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_rank, a_scales_tile_shape, a_scales_desc_shape],
    a_smem_base: UnsafePointer[Scalar[a_type], MutAnyOrigin, address_space=AddressSpace.SHARED],
    b_smem_base: UnsafePointer[Scalar[b_type], MutAnyOrigin, address_space=AddressSpace.SHARED],
    a_scales_smem_base: UnsafePointer[Scalar[a_scales_type], MutAnyOrigin, address_space=AddressSpace.SHARED],
    load_mma_pipeline: ProducerConsumerPipeline[Int[UInt](num_pipeline_stages)],
    peer_cta_coord: Tuple[Int, Int, Int],
    work_tile_coord: Tuple[Int, Int],
    a_multicast_mask: UInt16,
    b_multicast_mask: UInt16,
    iter_idx: UInt,
    elect_one_cta: Bool,
    scheduler: TileScheduler[static_MN=scheduler.static_MN, tile_shape=scheduler.tile_shape, cluster=scheduler.cluster, cta_group=scheduler.cta_group, swizzle=scheduler.swizzle, swapAB=scheduler.swapAB],
    expert_ids: LayoutTensor[DType.int32, expert_ids_layout, ImmutAnyOrigin]
)

Was this page helpful?