
Mojo function

cp_async_bulk_tensor_shared_cluster_global_multicast

cp_async_bulk_tensor_shared_cluster_global_multicast[dst_type: AnyType, mbr_type: AnyType, rank: Int](dst_mem: UnsafePointer[dst_type, address_space=AddressSpace(3)], tma_descriptor: UnsafePointer[NoneType], mem_bar: UnsafePointer[mbr_type, address_space=AddressSpace(3)], coords: Index[rank], multicast_mask: SIMD[uint16, 1])

Initiates an asynchronous multicast load that uses NVIDIA's Tensor Memory Accelerator (TMA) to copy tensor data from global memory into the shared memory of multiple CTAs in a cluster.

A single read from global memory is delivered to the shared memory of every CTA selected by multicast_mask. CTAs that need the same tile therefore do not each load it separately, which reduces global memory traffic. The function supports rank-1 and rank-2 tensors and operates at cluster scope: every destination CTA must belong to the same thread block cluster as the issuing CTA.
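The bandwidth saving described above is simple arithmetic. The sketch below is an idealized model, not a measurement: it assumes every CTA needs the whole tile and ignores L2 cache hits, which can already absorb some duplicate reads when multicast is not used. The function name and tile shape are illustrative only.

```python
def global_bytes_loaded(tile_bytes: int, num_ctas: int, multicast: bool) -> int:
    """Bytes read from global memory so that num_ctas CTAs each hold one tile.

    Idealized model: ignores L2 cache hits and assumes every CTA
    needs the full tile.
    """
    if multicast:
        # One read, fanned out to every CTA selected by the mask.
        return tile_bytes
    # Without multicast, each CTA issues its own load of the same tile.
    return tile_bytes * num_ctas


# Example: a 128 x 64 tile of 2-byte (bf16) elements shared by 4 CTAs.
tile_bytes = 128 * 64 * 2  # 16384 bytes
print(global_bytes_loaded(tile_bytes, 4, multicast=False))  # 65536
print(global_bytes_loaded(tile_bytes, 4, multicast=True))   # 16384
```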

Note:

  • This operation is asynchronous. Wait on the memory barrier to confirm the copy has completed before reading the destination.
  • Only rank-1 and rank-2 tensors are supported.
  • Requires an NVIDIA GPU with TMA support (Hopper, compute capability 9.0, or later).
  • The memory barrier must be initialized before use.
  • The multicast_mask must match the cluster size and the desired distribution of data across CTAs.
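Because the copy is asynchronous, completion is tracked by a transaction-count barrier: each receiving CTA announces how many bytes it expects, and the TMA unit reports bytes as they land. The toy model below illustrates that protocol for a tile multicast to two CTAs. It is a hypothetical host-side sketch (the class and method names are not the Mojo API) and it simplifies the hardware mbarrier to one arrival count plus one byte count per phase.

```python
class TxBarrier:
    """Toy model of a transaction-count barrier (hypothetical, not the Mojo API).

    A phase completes when all expected thread arrivals have happened AND
    every expected byte has been delivered.
    """

    def __init__(self, arrivals: int) -> None:
        self.expected_arrivals = arrivals
        self.pending_arrivals = arrivals
        self.tx_count = 0  # bytes still outstanding (may go briefly negative)
        self.phase = 0

    def arrive_expect_tx(self, nbytes: int) -> None:
        # Consumer thread: arrive and announce the bytes this phase will receive.
        self.tx_count += nbytes
        self.pending_arrivals -= 1
        self._maybe_complete()

    def complete_tx(self, nbytes: int) -> None:
        # Copy engine: report nbytes written into this CTA's shared memory.
        self.tx_count -= nbytes
        self._maybe_complete()

    def _maybe_complete(self) -> None:
        if self.pending_arrivals == 0 and self.tx_count == 0:
            self.phase ^= 1  # flip phase: waiters may now read the tile
            self.pending_arrivals = self.expected_arrivals

    def phase_done(self, phase: int) -> bool:
        # True once the given phase has completed.
        return self.phase != phase


TILE_BYTES = 16384
barriers = [TxBarrier(arrivals=1) for _ in range(2)]  # one per destination CTA
for bar in barriers:
    bar.arrive_expect_tx(TILE_BYTES)  # each CTA expects the full tile
print([bar.phase_done(0) for bar in barriers])  # [False, False]: still in flight

# The multicast copy lands in both CTAs and signals each CTA's barrier.
for bar in barriers:
    bar.complete_tx(TILE_BYTES)
print([bar.phase_done(0) for bar in barriers])  # [True, True]
```

Reading the destination tile before the phase flips would observe partially written data, which is why the note above requires waiting on the barrier.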

Args:

  • dst_mem (UnsafePointer[dst_type, address_space=AddressSpace(3)]): Pointer to the destination in shared memory where the tensor data will be copied. Must be properly aligned according to TMA requirements.
  • tma_descriptor (UnsafePointer[NoneType]): Pointer to the TMA descriptor containing metadata about tensor layout and memory access patterns.
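  • mem_bar (UnsafePointer[mbr_type, address_space=AddressSpace(3)]): Pointer to the shared memory barrier (mbarrier) that tracks completion of the asynchronous copy. With multicast, completion is signaled on the barrier at the same shared memory offset in each destination CTA, so every receiving CTA needs an initialized barrier at that offset.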
  • coords (Index[rank]): Coordinates specifying which tile of the tensor to copy. For rank-1 tensors, this is a single coordinate. For rank-2 tensors, this contains both row and column coordinates.
  • multicast_mask (SIMD[uint16, 1]): A 16-bit bitmask in which bit i corresponds to the CTA with rank i in the cluster. Each set bit selects a CTA whose shared memory receives a copy of the loaded tile, so one load can feed several CTAs.
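Building multicast_mask amounts to setting one bit per destination CTA. The sketch below computes masks for a 2D cluster in plain Python; the resulting integer is what would be passed as the SIMD[uint16, 1] argument. It assumes cluster ranks are linearized with x varying fastest (rank = x + y * cluster_x), and the helper names are hypothetical.

```python
MAX_CLUSTER_CTAS = 16  # the mask has 16 bits, one per CTA


def cluster_rank(x: int, y: int, cluster_x: int) -> int:
    # Assumption: ranks are linearized with x varying fastest.
    return x + y * cluster_x


def _check_cluster(cluster_x: int, cluster_y: int) -> None:
    if cluster_x * cluster_y > MAX_CLUSTER_CTAS:
        raise ValueError("cluster has more CTAs than the 16-bit mask can address")


def row_multicast_mask(y: int, cluster_x: int, cluster_y: int) -> int:
    """Mask selecting every CTA in cluster row y."""
    _check_cluster(cluster_x, cluster_y)
    mask = 0
    for x in range(cluster_x):
        mask |= 1 << cluster_rank(x, y, cluster_x)
    return mask


def col_multicast_mask(x: int, cluster_x: int, cluster_y: int) -> int:
    """Mask selecting every CTA in cluster column x."""
    _check_cluster(cluster_x, cluster_y)
    mask = 0
    for y in range(cluster_y):
        mask |= 1 << cluster_rank(x, y, cluster_x)
    return mask


# 2 x 2 cluster: ranks are (0,0)=0, (1,0)=1, (0,1)=2, (1,1)=3.
print(f"{row_multicast_mask(0, 2, 2):04b}")  # 0011
print(f"{row_multicast_mask(1, 2, 2):04b}")  # 1100
print(f"{col_multicast_mask(0, 2, 2):04b}")  # 0101
print(f"{col_multicast_mask(1, 2, 2):04b}")  # 1010
```

A row mask suits a tile that every CTA in one cluster row needs. Only the issuing CTA's tile coordinates determine what is loaded; the mask determines only where it is delivered.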