Mojo function
cp_async_bulk_tensor_shared_cluster_global_multicast
cp_async_bulk_tensor_shared_cluster_global_multicast[dst_type: AnyType, mbr_type: AnyType, rank: Int](dst_mem: UnsafePointer[dst_type, address_space=AddressSpace(3)], tma_descriptor: UnsafePointer[NoneType], mem_bar: UnsafePointer[mbr_type, address_space=AddressSpace(3)], coords: Index[rank], multicast_mask: SIMD[uint16, 1])
Initiates an asynchronous multicast load using NVIDIA's Tensor Memory Accelerator (TMA) to copy tensor data from global memory to the shared memory of multiple CTAs in a cluster.
This function performs a multicast copy in which a single read from global memory is delivered to the shared memory of several CTAs at once, reducing global memory bandwidth usage compared with each CTA issuing its own load. It supports rank-1 and rank-2 tensors and relies on cluster-level synchronization.
Note:
- This operation is asynchronous; use appropriate memory barriers to ensure the copy has completed.
- Only rank-1 and rank-2 tensors are supported.
- Requires an NVIDIA GPU with TMA support.
- The memory barrier should be properly initialized before use.
- The multicast_mask must be configured according to the cluster size and the desired distribution.
Args:
- dst_mem (UnsafePointer[dst_type, address_space=AddressSpace(3)]): Pointer to the destination in shared memory where the tensor data will be copied. Must be properly aligned according to TMA requirements.
- tma_descriptor (UnsafePointer[NoneType]): Pointer to the TMA descriptor containing metadata about tensor layout and memory access patterns.
- mem_bar (UnsafePointer[mbr_type, address_space=AddressSpace(3)]): Pointer to a shared memory barrier used for synchronizing the asynchronous copy operation across threads in the cluster.
- coords (Index[rank]): Coordinates specifying which tile of the tensor to copy. For rank-1 tensors, this is a single coordinate. For rank-2 tensors, this contains both row and column coordinates.
- multicast_mask (SIMD[uint16, 1]): A 16-bit bitmask where each bit corresponds to a CTA in the cluster. Set bits indicate which CTAs will receive a copy of the loaded data. This enables efficient data sharing across multiple CTAs.
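As an illustration of how a multicast_mask value might be assembled, the following is a minimal, language-neutral sketch written in Python rather than Mojo. The helper name `multicast_mask` and its `cluster_size` parameter are hypothetical and not part of the Mojo API. The sketch also assumes that bit i of the mask selects the CTA whose rank within the cluster is i; the description above only states that each bit corresponds to a CTA, so check that mapping against the hardware documentation before relying on it.

```python
def multicast_mask(cta_ranks, cluster_size):
    """Return a 16-bit mask with one bit set per receiving CTA.

    Hypothetical helper. Assumes bit i selects the CTA with rank i
    inside the cluster (an assumption, not stated by the API docs).
    """
    # A 16-bit mask can address at most 16 CTAs.
    if not 1 <= cluster_size <= 16:
        raise ValueError("cluster_size must be between 1 and 16")

    mask = 0
    for rank in cta_ranks:
        # Reject ranks that do not exist in this cluster.
        if not 0 <= rank < cluster_size:
            raise ValueError(f"CTA rank {rank} is outside the cluster")
        mask |= 1 << rank
    return mask


# Broadcast to every CTA in a 4-CTA cluster.
all_four = multicast_mask(range(4), cluster_size=4)

# Deliver only to the CTAs with ranks 0 and 2.
ranks_0_and_2 = multicast_mask([0, 2], cluster_size=4)
```

Under the stated bit-to-rank assumption, `all_four` is 0b1111 and `ranks_0_and_2` is 0b0101. The resulting integer would then be passed as the multicast_mask argument, converted to a 16-bit unsigned value.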