
Mojo struct

SHMEMContext

struct SHMEMContext[tcp: Bool = False]

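A context for running kernels on a GPU with SHMEM support. Use it as a context manager: on exit, it finalizes SHMEM and cleans up resources. The tcp parameter selects how SHMEM is bootstrapped. With tcp=False (the default), initialization relies on MPI. With tcp=True, it uses TCP bootstrapping with a unique ID.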

Example:

from shmem import SHMEMContext

def kernel():
    print("hello from the GPU")

with SHMEMContext() as ctx:
    ctx.enqueue_function[kernel](grid_dim=1, block_dim=1)
    ctx.synchronize()

Implemented traits

AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

Methods

__init__

__init__(out self, team: Int32 = Int32(2)) where (tcp == False)

Initializes a device context with SHMEM support.

This constructor initializes MPI and SHMEM, and creates a device context for the GPU assigned to the current PE (processing element).

Warning: if you're not using this as a context manager, you must call SHMEMContext.finalize() manually.

Raises:

If initialization fails.
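A minimal sketch of the manual lifecycle that the warning above describes: construct the context directly, do the work, and call finalize() yourself. The kernel and the run_without_context_manager helper are hypothetical names, and the sketch assumes each process is started once per PE by whatever MPI launcher your SHMEM installation expects.

```mojo
from shmem import SHMEMContext


def kernel():
    print("hello from the GPU")


def run_without_context_manager() raises:
    # Initializes MPI and SHMEM and creates a device context for this PE's GPU.
    var ctx = SHMEMContext()
    ctx.enqueue_function[kernel](grid_dim=1, block_dim=1)
    # Wait for the kernel to finish before tearing anything down.
    ctx.synchronize()
    # Required when not using `with`: clean up SHMEM and MPI resources.
    ctx.finalize()


def main() raises:
    run_without_context_manager()
```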

__init__(out self, ctx: DeviceContext) where (tcp == False)

Initializes a device context with SHMEM support, using one thread per GPU.

This constructor expects MPI to have already been initialized in the main thread. It then initializes SHMEM and creates a device context for the associated PE on this node.

Warning: if you're not using this as a context manager, you must call SHMEMContext.finalize() manually.

Raises:

If initialization fails.

__init__(out self, ctx: DeviceContext, node_id: Int = -1, total_nodes: Int = -1, gpus_per_node: Int = -1, server_ip: String = "-1", server_port: Int = -1) where tcp

Initializes a device context with SHMEM support, using one thread per GPU and TCP bootstrapping with a unique ID.

Warning: if you're not using this as a context manager, you must call SHMEMContext.finalize() manually.

Raises:

If initialization fails.

__del__

__del__(deinit self)

Context manager exit method.

Automatically finalizes SHMEM when exiting the context.

__enter__

__enter__(var self) -> Self

Context manager entry method.

Returns:

Self: Self for use in with statements.

finalize

finalize(mut self)

Finalizes the SHMEM runtime environment.

Cleans up SHMEM and MPI resources.

Raises:

If SHMEM or MPI finalization fails.

barrier_all

barrier_all(self)

Performs a barrier synchronization across all PEs.

All PEs must call this function before any PE can proceed past the barrier.

Raises:

If the barrier operation fails.
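A sketch of a common pattern: finish this PE's GPU work, then wait for all other PEs. It assumes (this is not stated above) that barrier_all() is a host-side call that does not itself wait for kernels still queued on the GPU stream, which is why synchronize() runs first. The kernel and run_then_barrier helper are hypothetical.

```mojo
from shmem import SHMEMContext


def kernel():
    print("hello from the GPU")


def run_then_barrier() raises:
    with SHMEMContext() as ctx:
        ctx.enqueue_function[kernel](grid_dim=1, block_dim=1)
        # Wait for this PE's own kernel to finish.
        ctx.synchronize()
        # Then block until every PE has reached this barrier.
        ctx.barrier_all()


def main() raises:
    run_then_barrier()
```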

enqueue_create_buffer

enqueue_create_buffer[dtype: DType](self, size: Int) -> SHMEMBuffer[dtype]

Creates a SHMEM buffer that can be accessed by all PEs.

Parameters:

  • dtype (DType): The data type of elements in the buffer.

Args:

  • size (Int): Number of elements in the buffer.

Returns:

SHMEMBuffer[dtype]: A SHMEMBuffer instance for the allocated memory.

Raises:

String: If buffer creation fails.
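A minimal allocation sketch. It assumes, as is usual for SHMEM symmetric memory, that every PE makes the same call with the same size, and it uses a barrier so that no PE proceeds before the others have allocated. The allocate_shared_buffer helper is hypothetical. Because Mojo destroys values after their last use, the final transfer keeps buf alive until after the barrier.

```mojo
from shmem import SHMEMContext


def allocate_shared_buffer() raises:
    with SHMEMContext() as ctx:
        # 1024 float32 elements, accessible by all PEs.
        var buf = ctx.enqueue_create_buffer[DType.float32](1024)
        # Wait for queued work on this PE, then for every other PE.
        ctx.synchronize()
        ctx.barrier_all()
        # Transfer out explicitly so buf lives until after the barrier.
        _ = buf^


def main() raises:
    allocate_shared_buffer()
```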

enqueue_function

enqueue_function[declared_arg_types: TypeList[declared_arg_types.values], //, func: def(*args: *declared_arg_types) -> None, *actual_arg_types: DevicePassable, *, dump_asm: Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *actual_arg_types.values, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List(__list_literal__=NoneType(None)), var constant_memory: List[ConstantMemoryMapping] = List(__list_literal__=NoneType(None)), func_attribute: OptionalReg[FuncAttribute] = None)

Compiles and enqueues a kernel for execution on this device.

You can pass the function directly to enqueue_function without compiling it first:

from shmem import SHMEMContext

def kernel():
    print("hello from the GPU")

with SHMEMContext() as ctx:
    ctx.enqueue_function[kernel](grid_dim=1, block_dim=1)
    ctx.synchronize()

Parameters:

  • declared_arg_types (TypeList[declared_arg_types.values]): The declared argument types from the function signature (usually inferred).
  • func (def(*args: *declared_arg_types) -> None): The function to launch.
  • *actual_arg_types (DevicePassable): The types of the arguments being passed (usually inferred).
  • dump_asm (Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path]): To dump the compiled assembly, pass True, a file path to dump to, or a function that returns a file path.
  • dump_llvm (Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path]): To dump the generated LLVM code, pass True, a file path to dump to, or a function that returns a file path.
  • _dump_sass (Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path]): Only runs on NVIDIA targets and requires the CUDA Toolkit to be installed. To dump the compiled SASS (NVIDIA's native GPU assembly), pass True, a file path to dump to, or a function that returns a file path.
  • _ptxas_info_verbose (Bool): Only runs on NVIDIA targets and requires the CUDA Toolkit to be installed. Changes dump_asm to output verbose PTX assembly (default False).

Args:

  • *args (*actual_arg_types.values): Variadic arguments that are passed to func.
  • grid_dim (Dim): The grid dimensions.
  • block_dim (Dim): The block dimensions.
  • cluster_dim (OptionalReg[Dim]): The cluster dimensions.
  • shared_mem_bytes (OptionalReg[Int]): The amount of dynamic shared memory per block, in bytes. This memory is shared among the threads of a block, not between blocks.
  • attributes (List[LaunchAttribute]): A list of launch attributes.
  • constant_memory (List[ConstantMemoryMapping]): A list of constant memory mappings.
  • func_attribute (OptionalReg[FuncAttribute]): A function attribute to set on the kernel, corresponding to a CUDA CUfunction_attribute value.
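A sketch of passing a kernel argument together with a multi-block launch configuration: positional arguments go to the kernel, and the launch settings are keyword-only, as in the signature above. The kernel, its scale argument, and the launch_grid helper are hypothetical, and the std.gpu import path is assumed by analogy with the std.gpu.host import used elsewhere on this page. The argument is stored in an Int variable first so its type matches the kernel's declared parameter exactly.

```mojo
from std.gpu import block_idx, thread_idx
from shmem import SHMEMContext


def kernel(scale: Int):
    # Every thread reports its grid position and the scalar argument.
    print("block", block_idx.x, "thread", thread_idx.x, "scale", scale)


def launch_grid() raises:
    with SHMEMContext() as ctx:
        # Typed as Int so it matches the kernel's declared parameter.
        var scale: Int = 3
        # Kernel arguments are positional; launch settings are keyword-only.
        ctx.enqueue_function[kernel](scale, grid_dim=2, block_dim=4)
        ctx.synchronize()


def main() raises:
    launch_grid()
```

With grid_dim=2 and block_dim=4, the kernel runs 8 threads in total on this PE.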

enqueue_function_collective_checked

enqueue_function_collective_checked[declared_arg_types: TypeList[declared_arg_types.values], //, func: def(*args: *declared_arg_types) -> None, *actual_arg_types: DevicePassable, *, dump_asm: Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *actual_arg_types.values, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List(__list_literal__=NoneType(None)), var constant_memory: List[ConstantMemoryMapping] = List(__list_literal__=NoneType(None)), func_attribute: OptionalReg[FuncAttribute] = None)

Compiles and enqueues a kernel for execution on this device.

You can pass the function directly to enqueue_function_collective_checked without compiling it first:

from shmem import SHMEMContext

def kernel():
    print("hello from the GPU")

with SHMEMContext() as ctx:
    ctx.enqueue_function_collective_checked[kernel](grid_dim=1, block_dim=1)
    ctx.synchronize()

Each enqueue of the same function and parameters incurs 50-500 nanoseconds of overhead, so repeated launches like the following pay that cost every time. If you reuse a kernel many times, compiling it once ahead of time removes the overhead.

with SHMEMContext() as ctx:
    ctx.enqueue_function_collective_checked[kernel](grid_dim=1, block_dim=1)
    ctx.enqueue_function_collective_checked[kernel](grid_dim=1, block_dim=1)
    ctx.synchronize()

Parameters:

  • declared_arg_types (TypeList[declared_arg_types.values]): The declared argument types from the function signature (usually inferred).
  • func (def(*args: *declared_arg_types) -> None): The function to launch.
  • *actual_arg_types (DevicePassable): The types of the arguments being passed (usually inferred).
  • dump_asm (Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path]): To dump the compiled assembly, pass True, a file path to dump to, or a function that returns a file path.
  • dump_llvm (Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path]): To dump the generated LLVM code, pass True, a file path to dump to, or a function that returns a file path.
  • _dump_sass (Variant[Bool, Path, StringSlice[StaticConstantOrigin], def() capturing -> Path]): Only runs on NVIDIA targets and requires the CUDA Toolkit to be installed. To dump the compiled SASS (NVIDIA's native GPU assembly), pass True, a file path to dump to, or a function that returns a file path.
  • _ptxas_info_verbose (Bool): Only runs on NVIDIA targets and requires the CUDA Toolkit to be installed. Changes dump_asm to output verbose PTX assembly (default False).

Args:

  • *args (*actual_arg_types.values): Variadic arguments that are passed to func.
  • grid_dim (Dim): The grid dimensions.
  • block_dim (Dim): The block dimensions.
  • cluster_dim (OptionalReg[Dim]): The cluster dimensions.
  • shared_mem_bytes (OptionalReg[Int]): The amount of dynamic shared memory per block, in bytes. This memory is shared among the threads of a block, not between blocks.
  • attributes (List[LaunchAttribute]): A list of launch attributes.
  • constant_memory (List[ConstantMemoryMapping]): A list of constant memory mappings.
  • func_attribute (OptionalReg[FuncAttribute]): A function attribute to set on the kernel, corresponding to a CUDA CUfunction_attribute value.

synchronize

synchronize(self)

Blocks until all asynchronous calls on the stream associated with this device context have completed.

Raises:

If synchronization fails.

get_device_context

get_device_context(self) -> DeviceContext

Returns the device context associated with this SHMEMContext.

Returns:

DeviceContext: The device context associated with this SHMEMContext.
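A sketch of using the underlying DeviceContext for ordinary, non-SHMEM device work alongside the SHMEM context. It assumes the standard DeviceContext.enqueue_create_buffer and synchronize methods, and that a plain device buffer, unlike a SHMEMBuffer, is not set up for access by other PEs. The use_device_context helper is hypothetical.

```mojo
from shmem import SHMEMContext


def use_device_context() raises:
    with SHMEMContext() as ctx:
        var dev = ctx.get_device_context()
        # Ordinary device memory on this PE's GPU (not a SHMEMBuffer).
        var local = dev.enqueue_create_buffer[DType.float32](16)
        dev.synchronize()
        # Keep the buffer alive until the device work has completed.
        _ = local^


def main() raises:
    use_device_context()
```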

number_of_devices

static number_of_devices(*, api: String = DeviceContext.default_device_info.api) -> Int

Returns the number of devices available that support the specified API.

This function queries the system for available devices that support the requested API (such as CUDA or HIP). It's useful for determining how many accelerators are available before allocating resources or distributing work.

Args:

  • api (String): Requested device API (for example, "cuda" or "hip"). Defaults to the device API specified by the current target accelerator.

Returns:

Int: The number of available devices supporting the specified API.
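Because number_of_devices is static, it can be called on the type itself without constructing a SHMEMContext, for example to check how many GPUs are present before distributing work. A minimal sketch; count_devices is a hypothetical helper.

```mojo
from shmem import SHMEMContext


def count_devices() -> Int:
    # Static method: called on the type, so no SHMEMContext instance is needed.
    return SHMEMContext.number_of_devices()


def main():
    print("available devices:", count_devices())
```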