
Mojo function

shmem_launch

shmem_launch[func: def(ctx: SHMEMContext[_]) raises -> None]()

Takes a function that defines a SHMEM program and launches it on one thread per attached GPU. For example:

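from shmem import SHMEMContext, shmem_launch, shmem_my_pe, shmem_n_pes
from testing import assert_equal

# `simple_shift_kernel` is a device kernel defined elsewhere (not shown here).

def simple_shift(ctx: SHMEMContext) raises:
    # Allocate a one-element Int32 buffer on this PE's GPU.
    var destination = ctx.enqueue_create_buffer[DType.int32](1)

    # Launch the kernel as a single block containing a single thread.
    ctx.enqueue_function[simple_shift_kernel](
        destination, grid_dim=1, block_dim=1
    )

    # Synchronize all PEs so that remote writes to `destination` have completed.
    ctx.barrier_all()

    # Copy the received value back into host memory.
    var msg = Int32(0)
    destination.enqueue_copy_to(UnsafePointer(to=msg))

    # Block until the enqueued copy has finished before reading `msg`.
    ctx.synchronize()

    var mype = shmem_my_pe()
    print("PE:", mype, "received message:", msg)
    assert_equal(msg, (mype + 1) % shmem_n_pes())

def main() raises:
    # Run `simple_shift` once per attached GPU.
    shmem_launch[simple_shift]()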

shmem_launch initializes SHMEM, runs the program in parallel on every attached GPU, and handles cleanup afterward. On NVIDIA GPUs, it also initializes and finalizes MPI on the main thread.

If the launched function raises an unhandled exception, the program aborts and reports the device ID along with the exception's error message.
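The assertion in the example expects each PE to receive (mype + 1) % shmem_n_pes(). As a small illustration, assuming a hypothetical node with four GPUs, the following Mojo sketch computes that expected value for each PE in plain host code with no SHMEM calls. The helper expected_message is hypothetical and not part of the shmem API.

```mojo
# Hypothetical helper, not part of the shmem API: the value the example's
# assertion expects PE `mype` to receive on a node with `n_pes` GPUs.
def expected_message(mype: Int, n_pes: Int) -> Int:
    return (mype + 1) % n_pes


def main():
    # Assumed node size for illustration; shmem_n_pes() gives the real count.
    var n_pes = 4
    for mype in range(n_pes):
        print("PE:", mype, "expects message:", expected_message(mype, n_pes))
```

With four PEs, PEs 0 to 2 expect 1 to 3, and the last PE wraps around to 0 because of the modulo.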

Parameters:

  • func (def(ctx: SHMEMContext[_]) raises -> None): The function to run once per attached GPU per node.

Raises:

If SHMEM initialization or the launched function fails.