Mojo function
load_AB
load_AB[a_type: DType, b_type: DType, a_scales_type: DType, a_rank: Int, a_tile_shape: IndexList[a_rank], a_desc_shape: IndexList[a_rank], b_rank: Int, b_tile_shape: IndexList[b_rank], b_desc_shape: IndexList[b_rank], a_scales_rank: Int, a_scales_tile_shape: IndexList[a_scales_rank], a_scales_desc_shape: IndexList[a_scales_rank], a_dim0: Int, a_dim1: Int, a_num_tiles: Int, a_swizzle_bytes: Int, b_dim0: Int, b_dim1: Int, b_num_tiles: Int, b_swizzle_bytes: Int, a_scales_smem_layout: Layout, num_pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](a_tma_op: TMATensorTile[a_type, a_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_rank, b_tile_shape, b_desc_shape], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_rank, a_scales_tile_shape, a_scales_desc_shape], a_smem_tiles: SMemTileArray2D[a_type, a_dim0, a_dim1, a_num_tiles, a_swizzle_bytes], b_smem_tiles: SMemTileArray2D[b_type, b_dim0, b_dim1, b_num_tiles, b_swizzle_bytes], a_scales_smem: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[num_pipeline_stages], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt, elect_one_cta: Bool)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!