
Mojo package

shmem

Implements a subset of OpenSHMEM functionality.

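It abstracts over both NVSHMEM (NVIDIA) and ROCSHMEM (AMD). It exposes an API similar to DeviceContext, together with a symmetric heap that GPUs can access both within a node and across nodes.

In the following example, each processing element (PE) writes its own PE number into a buffer on the next PE in a ring. It then checks the value it received.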

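```mojo
from std.testing import assert_equal
from shmem import shmem_my_pe, shmem_n_pes, shmem_p, SHMEMContext


def simple_shift_kernel(destination: UnsafePointer[Int32, _]):
    # Each PE sends its own PE number to the next PE in a ring.
    var mype = shmem_my_pe()
    var npes = shmem_n_pes()
    var peer = (mype + 1) % npes

    # One-sided put: write `mype` into `destination` on PE `peer`.
    shmem_p(destination, mype, peer)


def main() raises:
    with SHMEMContext() as ctx:
        # One-element buffer on the symmetric heap (one copy per PE).
        var destination = ctx.enqueue_create_buffer[DType.int32](1)
        ctx.enqueue_function[simple_shift_kernel](
            destination, grid_dim=1, block_dim=1
        )
        # Barrier across all PEs so every put has landed before reading.
        ctx.barrier_all()

        # Copy the received value back to the host.
        var msg = Int32(0)
        destination.enqueue_copy_to(UnsafePointer(to=msg))

        ctx.synchronize()

        print("PE:", ctx.my_pe(), "received message:", msg)

        # PE k is written to by its predecessor, so it holds k - 1 (mod n_pes).
        assert_equal(msg, (ctx.my_pe() + ctx.n_pes() - 1) % ctx.n_pes())
```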
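The ring arithmetic used by the kernel can be checked on the host, without GPUs or SHMEM. The sketch below is plain Mojo with no `shmem` calls and assumes a hypothetical `npes = 4`. For each PE, it prints the peer it writes to and the PE whose message it receives. Because every PE puts into its successor `(mype + 1) % npes`, PE `k` receives from its predecessor `(k + npes - 1) % npes`. Adding `npes` before taking the remainder keeps the operand non-negative for PE 0.

```mojo
from std.testing import assert_equal


def main() raises:
    # Hypothetical PE count, chosen only for illustration.
    var npes = 4
    for mype in range(npes):
        # Same target computation as `simple_shift_kernel`.
        var peer = (mype + 1) % npes
        # The predecessor in the ring is the PE that writes to `mype`.
        var source = (mype + npes - 1) % npes
        # Sanity check: the peer we write to sees us as its source.
        assert_equal((peer + npes - 1) % npes, mype)
        print("PE", mype, "sends to PE", peer, "and receives from PE", source)
```

With two PEs the successor and predecessor coincide. With three or more they differ, which is why the received value is the predecessor's PE number.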

Modules