Mojo function
load_AB_SFA
load_AB_SFA[a_type: DType, b_type: DType, sfa_dtype: DType, a_rank: Int, a_tile_shape: IndexList[a_rank], a_desc_shape: IndexList[a_rank], b_rank: Int, b_tile_shape: IndexList[b_rank], b_desc_shape: IndexList[b_rank], sfa_rank: Int, sfa_tile_shape: IndexList[sfa_rank], sfa_desc_shape: IndexList[sfa_rank], a_smem_layout: Layout, b_smem_layout: Layout, sfa_smem_layout: Layout, num_pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], num_sf_k_tiles: Int, cta_group: Int = 1, k_group_size: UInt = 1](a_tma_op: TMATensorTile[a_type, a_rank, a_tile_shape, a_desc_shape], b_tma_op: TMATensorTile[b_type, b_rank, b_tile_shape, b_desc_shape], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_rank, sfa_tile_shape, sfa_desc_shape], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfa_smem: LayoutTensorIter[sfa_dtype, sfa_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[num_pipeline_stages], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!