Mojo struct
SharedMemBarrier
@register_passable(trivial)
struct SharedMemBarrier
A hardware-accelerated synchronization primitive for GPU shared memory operations.
This struct provides a barrier mechanism for coordinating thread execution and memory transfers in GPU kernels, particularly Tensor Memory Accelerator (TMA) operations. It wraps the hardware's shared memory barrier instructions (the `mbarrier` family on NVIDIA GPUs), so threads can wait both on other threads arriving and on asynchronous memory transfers completing.
Key features:
- Thread synchronization within a thread block and, through `arrive_cluster`, across thread blocks in a cluster
- Memory transfer completion tracking
- Hardware-accelerated barrier operations
- Support for phased synchronization
This barrier is particularly useful for ensuring that shared memory operations complete before dependent computations begin, which is critical for maintaining data consistency in high-performance GPU kernels.
Fields
- mbar (SIMD[int64, 1]): Shared memory location used for the barrier state.
Implemented traits
AnyType, CollectionElement, Copyable, ExplicitlyCopyable, Movable, UnknownDestructibility
Methods
init
init(ref [3] self, num_threads: SIMD[int32, 1] = 1)
Initialize the barrier state with the expected number of threads.
Sets up the barrier to expect arrivals from the specified number of threads before it can be satisfied. Call it once, from a single thread, before any thread uses the barrier, and follow it with a block-wide barrier so that every thread sees the initialized state.
Args:
- num_threads (SIMD[int32, 1]): Number of threads that must arrive at the barrier before it is satisfied. Defaults to 1.
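The usual initialization order is: one thread calls `init`, a block-wide barrier publishes the state, and only then do threads arrive and wait. The sketch below assumes a Hopper-class NVIDIA GPU, a launch with a single block of exactly `NUM_THREADS` threads (host-side launch code is omitted), and the import paths shown, which may differ between Mojo releases. `NUM_THREADS` and `init_then_sync_kernel` are hypothetical names.

```mojo
from gpu import barrier, thread_idx
from gpu.memory import AddressSpace
from memory import stack_allocation
from layout.tma_async import SharedMemBarrier

# Hypothetical block size; assumes the kernel is launched with exactly this many threads.
alias NUM_THREADS = 128


fn init_then_sync_kernel():
    # Reserve 8-byte-aligned shared memory for one barrier.
    var mbar = stack_allocation[
        1, SharedMemBarrier, address_space = AddressSpace.SHARED, alignment=8
    ]()

    # Exactly one thread initializes the barrier with the expected arrival count.
    if thread_idx.x == 0:
        mbar[0].init(NUM_THREADS)

    # Block-wide sync so no thread touches the barrier before it is initialized.
    barrier()

    # Every thread arrives once; the phase completes after NUM_THREADS arrivals.
    _ = mbar[0].arrive()
    mbar[0].wait()
```

If fewer threads than `num_threads` ever arrive, the phase never completes and every waiting thread spins forever, so the count passed to `init` must match the number of arrivals the kernel actually performs.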
expect_bytes
expect_bytes(ref [3] self, bytes: SIMD[int32, 1])
Configure the barrier to expect a specific number of bytes to be transferred.
Used with TMA operations to declare the size of an incoming data transfer. The barrier is not satisfied until the specified number of bytes has been transferred, in addition to the expected thread arrivals, so threads that wait on it can safely read the destination buffer afterward.
Args:
- bytes (SIMD[int32, 1]): Number of bytes expected to be transferred.
wait
wait(ref [3] self, phase: SIMD[uint32, 1] = 0)
Wait until the barrier is satisfied.
Blocks the calling thread until the barrier is satisfied, which requires both the expected number of thread arrivals and any expected data transfer to complete. The wait is implemented as an efficient spin loop around a hardware try-wait instruction.
Note: Minimizes thread divergence during synchronization by using hardware-accelerated barrier instructions.
Args:
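- phase (SIMD[uint32, 1]): Parity (0 or 1) of the barrier phase to wait for. The barrier moves to a new phase each time it is satisfied, and consecutive phases alternate parity, so code that reuses a barrier flips this value after each wait. Defaults to 0.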
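When one barrier is reused, each completion starts a new phase with the opposite parity, so the caller must track which parity to pass to `wait`. The sketch below reuses a single barrier across a loop. It assumes a Hopper-class NVIDIA GPU, a single block of exactly `NUM_THREADS` threads (host launch code omitted), and the import paths shown, which may vary between Mojo releases. `NUM_THREADS`, `NUM_ITERS`, and `phased_sync_kernel` are hypothetical names.

```mojo
from gpu import barrier, thread_idx
from gpu.memory import AddressSpace
from memory import stack_allocation
from layout.tma_async import SharedMemBarrier

# Hypothetical launch configuration and loop length.
alias NUM_THREADS = 128
alias NUM_ITERS = 4


fn phased_sync_kernel():
    var mbar = stack_allocation[
        1, SharedMemBarrier, address_space = AddressSpace.SHARED, alignment=8
    ]()
    if thread_idx.x == 0:
        mbar[0].init(NUM_THREADS)
    barrier()  # publish the initialized barrier to every thread

    # Parity of the phase this thread is about to wait for.
    var phase: UInt32 = 0
    for _ in range(NUM_ITERS):
        # ... per-iteration work that other threads depend on goes here ...
        _ = mbar[0].arrive()  # count this thread toward the current phase
        mbar[0].wait(phase)   # returns once all NUM_THREADS arrivals are in
        phase ^= 1            # the next phase has the opposite parity
```

Flipping a parity bit, rather than re-initializing the barrier every iteration, is what lets one barrier guard a multi-stage loop. A thread that races ahead and arrives for the next phase cannot complete it early, because that phase still needs an arrival from every thread.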
unsafe_ptr
unsafe_ptr(ref [3] self) -> UnsafePointer[SIMD[int64, 1], address_space=AddressSpace(3), alignment=8, mut=self_is_mut, origin=self_is_origin]
Get an unsafe pointer to the barrier's memory location.
Provides low-level access to the shared memory location storing the barrier state. This method is primarily used internally by other barrier operations that need direct access to the underlying memory.
Returns:
An unsafe pointer to the barrier's memory location in shared memory, properly typed and aligned for barrier operations.
arrive_cluster
arrive_cluster(ref [3] self, cta_id: SIMD[uint32, 1], count: SIMD[uint32, 1] = 1)
Signal an arrival on the barrier at the same shared memory location in a specific CTA (Cooperative Thread Array) of the cluster.
This method is used in multi-CTA scenarios to coordinate barrier arrivals across different CTAs within a cluster. It lets a thread in one thread block count toward a barrier that threads in another block of the same cluster are waiting on.
Args:
- cta_id (SIMD[uint32, 1]): The rank, within the cluster, of the CTA whose barrier receives the arrival.
- count (SIMD[uint32, 1]): The number of arrivals to signal. Defaults to 1.
arrive
arrive(ref [3] self) -> Int
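Signal this thread's arrival at the barrier.
Each call counts as one of the arrivals the barrier expects for its current phase (the count set by `init`). Once all expected arrivals, and any expected bytes, are in, the phase completes and waiting threads are released.
Returns:
An `Int` state value produced by the hardware arrive operation. On NVIDIA GPUs this is the opaque 64-bit token returned by PTX `mbarrier.arrive`, which encodes the barrier's phase before the arrival rather than a plain arrival count; code that only needs to synchronize can discard it.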
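The arrival count does not have to equal the block size. In the sketch below, only the first `NUM_PRODUCERS` threads write shared memory and arrive, while every thread waits before reading. It assumes a Hopper-class NVIDIA GPU, a block of at least `NUM_PRODUCERS` threads, and an `output` buffer with one element per thread (host launch code omitted). It also relies on the PTX defaults of release semantics for an arrival and acquire semantics for a wait, so the producers' writes are visible once `wait` returns. `NUM_PRODUCERS`, `producer_consumer_kernel`, and `output` are hypothetical names, and import paths may vary between Mojo releases.

```mojo
from gpu import barrier, thread_idx
from gpu.memory import AddressSpace
from memory import stack_allocation, UnsafePointer
from layout.tma_async import SharedMemBarrier

# Hypothetical: threads [0, NUM_PRODUCERS) act as producers.
alias NUM_PRODUCERS = 32


fn producer_consumer_kernel(output: UnsafePointer[Float32]):
    var tile = stack_allocation[
        NUM_PRODUCERS, Float32, address_space = AddressSpace.SHARED
    ]()
    var mbar = stack_allocation[
        1, SharedMemBarrier, address_space = AddressSpace.SHARED, alignment=8
    ]()

    # Only producers arrive, so the expected count is NUM_PRODUCERS.
    if thread_idx.x == 0:
        mbar[0].init(NUM_PRODUCERS)
    barrier()

    var tid = Int(thread_idx.x)
    if tid < NUM_PRODUCERS:
        tile[tid] = Float32(tid) * 2.0  # write before arriving
        _ = mbar[0].arrive()            # returned state value is not needed

    # Producers and consumers alike wait for phase 0 to complete.
    mbar[0].wait()
    output[tid] = tile[tid % NUM_PRODUCERS]
```

Using the barrier instead of a second block-wide `barrier()` means only the threads that actually produce data have to arrive; the remaining threads just wait.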