# Mojo API Documentation > The Mojo API reference. This file contains all documentation content in a single document following the llmstxt.org standard. ## allgather
`allgather[dtype: DType, rank: Int, ngpus: Int](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], ngpus], output_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], (ngpus * ngpus)], rank_sigs: InlineArray[LegacyUnsafePointer[Signal], 8], ctxs: List[DeviceContext], _max_num_blocks: Optional[Int] = None)` Performs all-gather across GPUs with variadic output. Each device receives individual copies of all input buffers. The implementation automatically selects between P2P and non-P2P paths based on hardware capabilities. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of tensor elements. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of dimensions in input tensors. * ​ngpus ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of GPUs participating in all-gather. **Args:** * ​input\_buffers ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): Input buffers from each GPU. * ​output\_buffers ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): Flat array of ngpus \* ngpus output buffers. Layout: output\_buffers\[device\_idx \* ngpus + input\_idx] contains device\_idx's copy of input\_idx's data. * ​rank\_sigs ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): Signal pointers for P2P synchronization. * ​ctxs ([`List`](/mojo/stdlib/collections/list/List)): List of device contexts for participating GPUs. * ​\_max\_num\_blocks ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Maximum number of blocks for kernel launch (optional). `allgather[dtype: DType, rank: Int, ngpus: Int](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], ngpus], output_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], (ngpus * ngpus)], ctxs: List[DeviceContext])` Backward-compatible version without the rank\_sigs parameter. This version uses the naive implementation, because signal buffers with the proper lifetime cannot be allocated in this function. 
**Deprecated:** Use the `signal_buffers` overload of `allgather` instead.
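The flat output layout described above can be illustrated with a small indexing sketch. This is illustrative only; buffer creation and the `allgather` call itself are assumed to happen elsewhere, and `output_slot` is a hypothetical helper, not part of the API:

```mojo
# Index math for the flat output layout:
# output_buffers[device_idx * ngpus + input_idx] is device_idx's
# copy of input_idx's data.
fn output_slot(device_idx: Int, input_idx: Int, ngpus: Int) -> Int:
    return device_idx * ngpus + input_idx

fn main():
    # With 4 GPUs, device 2's copy of device 1's input lives at slot 9.
    print(output_slot(2, 1, 4))  # 2 * 4 + 1 = 9
```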
--- ## allgather (Allgather)
Multi-GPU allgather implementation that gathers values from multiple GPUs into an output buffer. This module provides an optimized implementation of allgather operations across multiple GPUs, supporting both peer-to-peer (P2P) and non-P2P communication patterns. The implementation automatically selects between approaches based on hardware capabilities: 1. P2P-based implementation (when P2P access is available): * Uses direct GPU-to-GPU memory access for better performance. * Optimized for NVLink and xGMI bandwidth utilization. * Uses vectorized memory access. 2. Non-P2P fallback implementation: * Copies data through device memory when direct GPU access isn't possible. * Simple but functional approach for systems without P2P support. ## Functions * [​`allgather`](./allgather): Performs all-gather across GPUs with variadic output.
--- ## TuningConfigAllreduce
`@register_passable(trivial)` `struct TuningConfigAllreduce` Tuning configuration for allreduce. ## Fields * ​ngpus (`Int`): Number of GPUs for running allreduce. * ​num\_bytes (`Int`): Total number of input bytes supported by the config. * ​sm\_version (`StaticString`): SM version (as string). * ​num\_blocks (`Int`): Number of thread blocks for running allreduce. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`TuningConfig`](/mojo/kernels/internal_utils/dispatch_utils/TuningConfig), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__str__` `__str__(self) -> String` **Returns:** `String`
--- ## allreduce
`allreduce[dtype: DType, rank: Int, ngpus: Int, output_lambda: OptionalReg[fn[dtype: DType, rank: Int, width: Int, *, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None] = None, pdl_level: PDLLevel = PDLLevel(), *, use_multimem: Bool = False, use_quickreduce: Bool = False](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], 1 if use_multimem else ngpus], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[LegacyUnsafePointer[Signal], 8], ctx: DeviceContext, _max_num_blocks: Optional[Int] = None, iteration: Int = 0)` Per-device allreduce: one instance per GPU builds its own output. High-level model * Each GPU runs one instance of this function in parallel with the others. * Every instance reads all inputs but writes only its own output buffer. * A Python-level fence is inserted across the outputs to prevent reordering. Two execution paths 1. P2P fast path (when peer access is available) * 1-stage kernel (latency-bound): each thread vector-loads from all GPUs, accumulates in higher precision, and writes directly to the result. * 2-stage kernel (bandwidth-bound): reduce-scatter then all-gather. Uses each GPU's `rank_sigs[*]` payload as a staging area for partitions. Diagram (per GPU r, 2-stage): * Stage 1: write reduced partition r into payload of `rank_sigs[r]`. * Stage 2: gather partitions from all peers' payloads into `out_r`. 2. Naive fallback (no P2P) * For GPU r: create local accumulator A\_r, allocate a temporary buffer S\_r, copy each peer input into S\_r and accumulate into A\_r, then apply the epilogue into `out_r`. Diagram (per GPU r, naive): in\_r → A\_r += in\_r; for i≠r: in\_i → tmp\_r → A\_r += tmp\_r; A\_r → out\_r Notes: * Inputs must have identical shape/dtype across GPUs. * Signal buffers must be sized at least `size_of(Signal) + payload_bytes` for the P2P 2-stage path, where `payload_bytes` equals the input tensor bytecount. * The naive path is automatically selected if P2P cannot be enabled. 
* The `use_multimem` parameter requires P2P access between GPUs to be enabled. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the tensor elements. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of dimensions in the tensors. * ​ngpus ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of GPUs participating in the allreduce. * ​output\_lambda ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Elementwise epilogue applied on the device result. * ​pdl\_level ([`PDLLevel`](/mojo/stdlib/gpu/primitives/grid_controls/PDLLevel)): Controls PDL behavior for P2P kernels. * ​use\_multimem ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to use multimem mode for improved performance. * ​use\_quickreduce ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, prefer the quickreduce 2-stage path when eligible. **Args:** * ​input\_buffers ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): Inputs from ALL GPUs (for P2P, these must be peer accessible). * ​output\_buffer ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output for THIS GPU. * ​rank\_sigs ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): Per-GPU Signal; header plus payload. Payload is used as scratch for the P2P 2-stage path. * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context for THIS GPU (device id → rank). * ​\_max\_num\_blocks ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Optional grid limit (dispatch selects a default otherwise). * ​iteration ([`Int`](/mojo/stdlib/builtin/int/Int)): Monotonic per-call counter used to color quickreduce flags. Increment each launch; ensures barrier flags are unique across iterations to prevent reuse hazards when reusing the same signal buffers.
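The per-device launch model described above can be sketched as follows. This is a hedged sketch, not the library's own driver code: `my_dtype`, `my_rank`, `inputs`, `outputs`, `sigs`, `ctxs`, and `step` are hypothetical names for state the caller is assumed to have set up (peer-accessible input buffers, per-GPU output buffers, signal pointers, and device contexts):

```mojo
# One allreduce instance per GPU, launched in parallel by the caller;
# each instance reads every GPU's input but writes only its own output.
for r in range(ngpus):
    allreduce[my_dtype, my_rank, ngpus](
        inputs,         # inputs from ALL GPUs (peer accessible for P2P)
        outputs[r],     # output for THIS GPU
        sigs,           # shared InlineArray of Signal pointers
        ctxs[r],        # device context selecting GPU r
        iteration=step,
    )
# Bump the counter once per collective call so quickreduce barrier
# flags stay unique when the same signal buffers are reused.
step += 1
```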
--- ## allreduce_2stage_quickreduce
`allreduce_2stage_quickreduce[dtype: DType, rank: Int, ngpus: Int, *, BLOCK_SIZE: Int, output_lambda: elementwise_epilogue_type, atom_size: Int](result: NDBuffer[dtype, rank, MutAnyOrigin], local_src: LegacyUnsafePointer[Scalar[dtype]], rank_sigs: InlineArray[LegacyUnsafePointer[Signal], 8], num_elements: Int, my_rank: Int, iteration: Int, num_tiles_total: Int)`
--- ## allreduce_2stage_quickreduce_tile
`allreduce_2stage_quickreduce_tile[dtype: DType, rank: Int, ngpus: Int, *, BLOCK_SIZE: Int, output_lambda: elementwise_epilogue_type, atom_size: Int, use_bufferio: Bool](result: NDBuffer[dtype, rank, MutAnyOrigin], local_src: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.GLOBAL if is_amd_gpu() else AddressSpace.GENERIC], rank_sigs: InlineArray[LegacyUnsafePointer[Signal], 8], num_elements: Int, my_rank: Int, tile: Int, num_tiles: Int, iteration: Int)`
--- ## get_sm_version
`get_sm_version() -> StaticString` **Returns:** `StaticString`
--- ## allreduce (Allreduce)
Multi-GPU allreduce implementation for efficient tensor reduction across GPUs. This module provides an optimized implementation of allreduce operations across multiple GPUs, supporting both peer-to-peer (P2P) and non-P2P communication patterns. The implementation automatically selects between two approaches based on hardware capabilities: 1. P2P-based implementation (when P2P access is available): * Uses direct GPU-to-GPU memory access for better performance * Implements both single-stage and two-stage algorithms: * Single-stage for latency-bound transfers (small tensors) * Two-stage (reduce-scatter + all-gather) for bandwidth-bound transfers (large tensors) * Optimized for NVLink bandwidth utilization * Uses vectorized memory access and higher precision accumulation 2. Non-P2P fallback implementation: * Copies data through host memory when direct GPU access isn't possible * Simple but functional approach for systems without P2P support The implementation is tuned for common GPU architectures (A100, H100) and includes parameters that can be adjusted for different hardware configurations. ## Per-Device Architecture The allreduce operation follows a per-device execution model: 1. **Single-Device Instances**: Each GPU runs its own instance of the allreduce operation. 2. **Parallel Execution**: The Python/Graph API layer is responsible for: * Creating one allreduce op instance per participating GPU. * Ensuring all instances execute in parallel. * Ensuring correctness by staging mo.fence. 3. **Device Affinity**: Each allreduce instance: * Executes on its assigned GPU (specified via device context). * Reads from all GPUs' input buffers (requires P2P access). * Writes only to its own output buffer. * Uses the same synchronization signals as other instances. 4. **Requirements**: * Peer-to-peer access must be enabled between all participating GPUs. * All instances must launch before any can complete (for synchronization). 
* The device context determines which GPU executes each instance. Limitations: * Number of elements must be a multiple of SIMD width. * Maximum of 8 GPUs supported. * All input/output buffers must have identical shapes. ## Visual Overview 1. 1-Stage P2P (latency-bound) Each GPU r reads its portion from every peer buffer directly (via P2P), accumulates, then writes to its result using the epilogue: ``` GPU r (result_r) src_ptrs[0] ─┐ src_ptrs[1] ─┼──► Σ (high-precision accum) ──► output_lambda ──► result_r ... ─┘ ``` Notes: * Vectorized loads from global memory on each GPU. * Good for small/latency-bound tensors. 2. 2-Stage P2P (bandwidth-bound) Stage 1 (reduce-scatter): Each GPU r reduces its assigned partition and writes into its own signal payload (the bytes after the Signal header). ``` src_ptrs[*] ──► reduce(partition r) ──► rank_sigs[r].payload (per-GPU) ``` Stage 2 (all-gather): Each GPU r gathers all partitions from peers' payloads and writes them to its result using the epilogue. ``` [payload_0], [payload_1], ..., [payload_{ngpus-1}] ──► result_r (via output_lambda) ``` For the naive allreduce (no P2P) per-device flow and staging details, see the `_allreduce_naive_single` docstring in this file. 
## `comptime` values ### `allreduce_table` `comptime allreduce_table = Table[TuningConfigAllreduce](List[TuningConfigAllreduce](TuningConfigAllreduce(-1, -1, "sm_90a", 216), TuningConfigAllreduce(4, 134217728, "sm_90a", 232), TuningConfigAllreduce(-1, -1, "sm_100a", 512), TuningConfigAllreduce(2, 8388608, "sm_100a", 512), TuningConfigAllreduce(2, 16777216, "sm_100a", 512), TuningConfigAllreduce(2, 33554432, "sm_100a", 512), TuningConfigAllreduce(2, 67108864, "sm_100a", 512), TuningConfigAllreduce(2, 134217728, "sm_100a", 512), TuningConfigAllreduce(4, 8388608, "sm_100a", 512), TuningConfigAllreduce(4, 16777216, "sm_100a", 512), TuningConfigAllreduce(4, 33554432, "sm_100a", 512), TuningConfigAllreduce(4, 67108864, "sm_100a", 512), TuningConfigAllreduce(4, 134217728, "sm_100a", 512), Tuple[]()), "allreduce_table")` ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[dtype: DType, rank: Int, width: Int, *, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None` ## Structs * [​`TuningConfigAllreduce`](./TuningConfigAllreduce): Parameters: ngpus: Number of GPUs for running allreduce. num\_bytes: Total number of input bytes supported by the config. sm\_version: SM version (as string). num\_blocks: Number of thread blocks for running allreduce. ## Functions * [​`allreduce`](./allreduce): Per-device allreduce: one instance per GPU builds its own output. * [​`allreduce_2stage_quickreduce`](./allreduce_2stage_quickreduce): * [​`allreduce_2stage_quickreduce_tile`](./allreduce_2stage_quickreduce_tile): * [​`get_sm_version`](./get_sm_version):
--- ## comm
Provides communication primitives for GPUs. This package includes functions for sending and receiving data between GPUs, as well as for synchronizing threads across GPUs. ## Packages * [​`vendor`](./vendor/): ## Modules * [​`allgather`](./allgather/): Multi-GPU allgather implementation that gathers values from multiple GPUs into an output buffer. * [​`allreduce`](./allreduce/): Multi-GPU allreduce implementation for efficient tensor reduction across GPUs. * [​`sync`](./sync/):
--- ## Signal
`@register_passable(trivial)` `struct Signal` A synchronization primitive for coordinating GPU thread blocks across multiple devices. This struct provides counter-based synchronization between thread blocks on different GPUs. It maintains two sets of counters: 1. self\_counter: Used by blocks on the current GPU to signal their progress 2. peer\_counter: Used to track progress of blocks on other GPUs Note: The counters use unsigned integers that may overflow, but this is safe since unsigned integer overflow has well-defined behavior. ## Fields * ​self\_counter (`StaticTuple[StaticTuple[UInt32, 8], 512]`): A 2D array of counters with shape (MAX\_NUM\_BLOCKS\_UPPER\_BOUND, MAX\_GPUS). Each counter tracks the progress of a specific thread block on the current GPU. Thread blocks increment their corresponding counter to signal completion of a phase, allowing other GPUs to detect when synchronization points are reached. The counters use atomic operations to ensure proper synchronization across devices. * ​peer\_counter (`StaticTuple[StaticTuple[StaticTuple[UInt32, 8], 512], 2]`): A 3D array of counters with shape (2, MAX\_NUM\_BLOCKS\_UPPER\_BOUND, MAX\_GPUS). Contains two sets of counters to handle two synchronization points safely. The dual counter design prevents race conditions where a peer block arrives at the second sync point before the current block passes the first sync point. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `flag_t` `comptime flag_t = DType.uint32`
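For the P2P 2-stage allreduce path, each GPU's signal buffer must hold the `Signal` header plus a payload the size of one input tensor. A minimal sizing sketch follows; it assumes the `sys.size_of` overloads for a struct type and a `DType`, and the allocation itself (which depends on your `DeviceContext` setup) is omitted:

```mojo
from sys import size_of

fn signal_buffer_bytes[dtype: DType](num_elements: Int) -> Int:
    # Header plus a payload large enough to stage one full input tensor.
    var payload_bytes = num_elements * size_of[dtype]()
    return size_of[Signal]() + payload_bytes
```

The returned byte count is the minimum size per GPU; the buffer should be zero-initialized before the first collective call so the counters start in a known state.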
--- ## can_enable_p2p
`can_enable_p2p() -> Bool` If peer-to-peer access is supported, enables it between all GPU pairs. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if P2P access is possible between all GPU pairs, False otherwise.
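Conceptually, the path selection that the collectives perform internally looks like the sketch below; you normally do not need to write this yourself, since `allreduce` and `allgather` already branch on it:

```mojo
# Illustration only: the collectives run this check internally.
if can_enable_p2p():
    # Fast path: direct GPU-to-GPU loads over NVLink/xGMI.
    pass
else:
    # Fallback: stage copies through device memory.
    pass
```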
--- ## group_end
`group_end()`
--- ## group_start
`group_start()`
--- ## sync
## `comptime` values ### `MAX_GPUS` `comptime MAX_GPUS = 8` Maximum number of GPUs supported in the allreduce implementation. This constant sets the upper bound for the number of GPUs supported in this algorithm. ### `MAX_NUM_BLOCKS_UPPER_BOUND` `comptime MAX_NUM_BLOCKS_UPPER_BOUND = 512` Maximum number of thread blocks to use for reduction kernels. This value has been empirically optimized through grid search across different GPU architectures. While this value is optimal for A100 GPUs, H100 GPUs may benefit from more blocks to fully saturate NVLink bandwidth. ## Structs * [​`Signal`](./Signal): A synchronization primitive for coordinating GPU thread blocks across multiple devices. ## Functions * [​`can_enable_p2p`](./can_enable_p2p): If peer-to-peer access is supported, enables it between all GPU pairs. * [​`group_end`](./group_end): * [​`group_start`](./group_start):
--- ## Communicators
`struct Communicators` ## Fields * ​ngpus (`Int`): * ​comms (`InlineArray[ncclComm_t, 8]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True`
--- ## allgather (Ccl)
`allgather[dtype: DType, rank: Int, ngpus: Int](inputs: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], ngpus], outputs: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], (ngpus * ngpus)], list_of_ctx: List[DeviceContext])`
--- ## allreduce (Ccl)
`allreduce[dtype: DType, rank: Int, ngpus: Int, output_lambda: OptionalReg[fn[dtype: DType, rank: Int, width: Int, *, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None] = None, pdl_level: PDLLevel = PDLLevel(), *, use_multimem: Bool = False, use_quickreduce: Bool = False](input_buffers: InlineArray[NDBuffer[dtype, rank, MutAnyOrigin], 1 if use_multimem else ngpus], output_buffer: NDBuffer[dtype, rank, MutAnyOrigin], rank_sigs: InlineArray[LegacyUnsafePointer[Signal], 8], ctx: DeviceContext, _max_num_blocks: Optional[Int] = None, iteration: Int = 0)` Per-GPU allreduce for use in multi-threaded contexts. Currently requires a prior single-threaded call to `init_comms`, since a thread-safe version has not yet been implemented.
--- ## group
`group() -> _Group` **Returns:** `_Group`
--- ## ccl
## `comptime` values ### `CCL_LIBRARY` `comptime CCL_LIBRARY = _Global["CCL_LIBRARY", _init_ccl_dylib]` ### `CCLAllGatherFn` `comptime CCLAllGatherFn = fn(LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType], Int, ncclDataType_t, LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType]) -> ncclResult_t` ### `CCLAllReduceFn` `comptime CCLAllReduceFn = fn(LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType], Int, ncclDataType_t, ncclRedOp_t, LegacyUnsafePointer[NoneType], LegacyUnsafePointer[NoneType]) -> ncclResult_t` ### `NCCL_LIBRARY_PATHS` `comptime NCCL_LIBRARY_PATHS = List[Path]("libnccl.so", "libnccl.so.2", "/usr/lib/x86_64-linux-gnu/libnccl.so", "/usr/lib/x86_64-linux-gnu/libnccl.so.2", Tuple[]())` ### `ncclComm_t` `comptime ncclComm_t = LegacyOpaquePointer` ### `RCCL_LIBRARY_PATHS` `comptime RCCL_LIBRARY_PATHS = List[Path]("librccl.so", "librccl.so.1", "/opt/rocm/lib/librccl.so", "/opt/rocm/lib/librccl.so.1", Tuple[]())` ## Structs * [​`Communicators`](./Communicators): * [​`ncclDataType_t`](./ncclDataType_t): * [​`ncclRedOp_t`](./ncclRedOp_t): * [​`ncclResult_t`](./ncclResult_t): ## Functions * [​`allgather`](./allgather): * [​`allreduce`](./allreduce): Per-GPU allreduce for use in multi-threaded contexts. * [​`group`](./group): * [​`init_comms`](./init_comms): Pre-initialize NCCL/RCCL communicators. * [​`is_allgather_available`](./is_allgather_available): * [​`is_allreduce_available`](./is_allreduce_available): * [​`ncclCommInitAll`](./ncclCommInitAll):
--- ## init_comms
`init_comms(ngpus: Int)` Pre-initialize NCCL/RCCL communicators. Must be called from a single thread before using allreduce from multiple threads. This ensures thread-safe initialization since ncclCommInitAll is not designed for concurrent calls.
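A minimal usage sketch, assuming the caller drives one GPU per worker thread and that thread spawning happens through machinery outside this module:

```mojo
# Single-threaded setup: ncclCommInitAll is not safe to call
# concurrently, so initialize the communicators exactly once
# before spawning any workers.
init_comms(ngpus)
# Afterwards, each worker thread may call the ccl allreduce overload
# with the device context for its own GPU.
```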
--- ## is_allgather_available
`is_allgather_available() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## is_allreduce_available
`is_allreduce_available() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## ncclCommInitAll
`ncclCommInitAll(comms: LegacyUnsafePointer[ncclComm_t], ndev: Int, devlist: LegacyUnsafePointer[Int32]) -> ncclResult_t` **Returns:** `ncclResult_t`
--- ## ncclDataType_t
`@register_passable(trivial)` `struct ncclDataType_t` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ncclBfloat16` `comptime ncclBfloat16 = ncclDataType_t(9)` ### `ncclFloat16` `comptime ncclFloat16 = ncclDataType_t(6)` ### `ncclFloat32` `comptime ncclFloat32 = ncclDataType_t(7)` ## Methods ### `__init__` `__init__(value: Int) -> Self`
--- ## ncclRedOp_t
`@register_passable(trivial)` `struct ncclRedOp_t` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ncclSum` `comptime ncclSum = ncclRedOp_t(0)` ## Methods ### `__init__` `__init__(value: Int) -> Self`
--- ## ncclResult_t
`@register_passable(trivial)` `struct ncclResult_t` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ncclSuccess` `comptime ncclSuccess = ncclResult_t(0)` ## Methods ### `__init__` `__init__(value: Int) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `write_to` `write_to(self, mut writer: T)`
--- ## vendor
## Modules * [​`ccl`](./ccl/):
--- ## extensibility
Includes the tensor package. ## Packages * [​`tensor`](./tensor/): APIs to create and manage tensors in a graph.
--- ## tensor
APIs to create and manage tensors in a graph. ## Modules * [​`io_spec`](./io_spec/): * [​`managed_tensor_slice`](./managed_tensor_slice/): Implements the `ManagedTensorSlice` type - a view of a tensor that doesn't own the underlying data. This type is used to build custom graph operations. * [​`operation_traits`](./operation_traits/): * [​`tensor_spec`](./tensor_spec/): You can import these APIs from the `max.tensor` package. * [​`transitional`](./transitional/): Utilities for the transitional period during NDBuffer deprecation.
--- ## IO
`@register_passable(trivial)` `struct IO` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `FusedInput` `comptime FusedInput = IO(2)` ### `FusedOutput` `comptime FusedOutput = IO(3)` ### `Input` `comptime Input = IO(1)` ### `Output` `comptime Output = IO(0)` ### `Unknown` `comptime Unknown = IO(-1)` ## Methods ### `__init__` `__init__(value: Int) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## IOSpec
`@register_passable(trivial)` `struct IOSpec[mut: Bool, input: IO]` Parameter used to encode whether a particular tensor argument to a DPS kernel is an output, input, or mutable input. ```mojo Input == IOSpec[False, IO.Input]() Output == IOSpec[True, IO.Output]() MutableInput == IOSpec[True, IO.Input]() FusedInput == IOSpec[False, IO.FusedInput]() FusedOutput == IOSpec[True, IO.FusedOutput]() ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True`
--- ## io_spec
## `comptime` values ### `FusedInput` `comptime FusedInput = IOSpec[False, IO.FusedInput]()` ### `FusedOutput` `comptime FusedOutput = IOSpec[True, IO.FusedOutput]()` ### `Input` `comptime Input = IOSpec[False, IO.Input]()` ### `IOUnknown` `comptime IOUnknown = IOSpec[True, IO.Unknown]()` ### `MutableInput` `comptime MutableInput = IOSpec[True, IO.Input]()` ### `Output` `comptime Output = IOSpec[True, IO.Output]()` ## Structs * [​`IO`](./IO): * [​`IOSpec`](./IOSpec): Parameter used to encode whether a particular tensor argument to a DPS kernel is an output, input, or mutable input.
--- ## ManagedTensorSlice
`@register_passable(trivial)` `struct ManagedTensorSlice[mut: Bool, input: IO, dtype: DType, rank: Int, //, io_spec: IOSpec[mut, input], *, static_spec: StaticTensorSpec[dtype, rank]]` A view of a tensor that does not own the underlying allocated pointer. When the object lifetime ends it does not free the underlying pointer. Conversely, if a `ManagedTensorSlice` is created, it will not extend the life of the underlying pointer. Therefore, the user must take care to keep the pointer alive until the last use of a `ManagedTensorSlice` instance. This class is useful for writing custom operations where memory is managed by an external runtime like in MAX's inference stack. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `address_space` `comptime address_space = static_spec.address_space` ### `alignment` `comptime alignment = static_spec.alignment` ### `device_type` `comptime device_type = LayoutTensor[dtype, static_spec.to_layout[dtype, rank](), MutAnyOrigin]` ### `exclusive` `comptime exclusive = static_spec.exclusive` ## Methods ### `__init__` `__init__(ptr: LegacyUnsafePointer[Scalar[dtype]], slices: InlineArray[Slice, rank], slicer_spec: RuntimeTensorSpec[dtype, rank]) -> Self` Initializes a ManagedTensorSlice from a pointer, array of slices and tensor spec. 
In general, custom operations should not create `ManagedTensorSlice` instances, but instead use the ones provided by the MAX inference engine. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype]], shape: IndexList[rank]) -> Self` Initializes a ManagedTensorSlice from a pointer and shape. In general, custom operations should not create `ManagedTensorSlice` instances, but instead use the ones provided by the MAX inference engine. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype]], shape: IndexList[rank], strides: IndexList[rank]) -> Self` Initializes a ManagedTensorSlice from a pointer, shape, and strides. In general, custom operations should not create `ManagedTensorSlice` instances, but instead use the ones provided by the MAX inference engine. ### `__getitem__` `__getitem__(self, indices: IndexList[rank]) -> Scalar[dtype]` Gets the value at the specified indices. **Args:** * ​indices ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The indices of the value to retrieve. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The value at the specified indices. `__getitem__(self, *indices: Int) -> Scalar[dtype]` Gets the value at the specified indices. **Args:** * ​\*indices ([`Int`](/mojo/stdlib/builtin/int/Int)): The indices of the value to retrieve. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The value at the specified indices. ### `__setitem__` `__setitem__(self, *indices: Int, *, val: Scalar[dtype])` Stores the value at the specified indices. **Args:** * ​\*indices ([`Int`](/mojo/stdlib/builtin/int/Int)): The indices of the value to store. * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to store. `__setitem__(self, indices: IndexList[rank], val: Scalar[dtype])` Stores the value at the specified indices. **Args:** * ​indices ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The indices of the value to store. * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to store. 
### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `spec` `spec(self) -> RuntimeTensorSpec[dtype, rank]` Gets the `TensorSpec` of this tensor slice, which provides meta-data about the tensor slice. **Returns:** [`RuntimeTensorSpec`](/mojo/tensor/tensor_spec/RuntimeTensorSpec): The static `TensorSpec` for this tensor slice. ### `shape` `shape(self) -> IndexList[rank]` Gets the shape of this tensor slice, as an `IndexList`. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The shape of this tensor slice. ### `dim_size` `dim_size(self, index: Int) -> Int` Gets the size of a given dimension of this tensor slice using a run time value. **Args:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the tensor slice in the given dimension. `dim_size[index: Int](self) -> Int` Gets the size of a given dimension of this tensor slice using a compile time value. **Parameters:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the tensor slice in the given dimension. ### `strides` `strides(self) -> IndexList[rank]` Gets the strides of this tensor slice, as an `IndexList`. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The strides of this tensor slice. ### `stride_length` `stride_length(self, index: Int) -> Int` Gets the length of the stride of a given dimension of this tensor slice using a run time value. **Args:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the tensor slice in the given dimension. 
`stride_length[index: Int](self) -> Int` Gets the length of the stride of a given dimension of this tensor slice using a compile time value. **Parameters:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The zero-based index of the dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the stride in the given dimension. ### `size` `size(self) -> Int` Computes the tensor slice's number of elements. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total number of elements in the tensor slice. ### `unsafe_ptr` `unsafe_ptr[_dtype: DType = dtype](self) -> LegacyUnsafePointer[Scalar[_dtype]]` Get the pointer stored in this tensor slice. Because this method exposes the underlying pointer, callers can violate the invariants of this tensor slice and cause unexpected behavior. Use it with caution. **Parameters:** * ​\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The type of the `UnsafePointer` in this tensor slice. **Returns:** `LegacyUnsafePointer`: The `UnsafePointer` which contains the data for this tensor slice. ### `load` `load[width: Int, _rank: Int, element_alignment: Int = 1](self, index: IndexList[_rank]) -> SIMD[dtype, width]` Gets data from this tensor slice as a `SIMD`. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the `SIMD` value. This must be large enough to contain the data from this tensor slice. * ​\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the tensor slice. * ​element\_alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Indicates the alignment of the pointer stored to memory. This is needed to issue vector loads for GPUs with strict alignment requirements. **Args:** * ​index ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): An `IndexList` of size `_rank` to indicate the dimension of the tensor slice to obtain data from. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Data from this tensor slice at dimension `index`. 
### `store` `store[width: Int, _rank: Int, element_alignment: Int = 1](self: ManagedTensorSlice[io_spec, static_spec=static_spec], index: IndexList[_rank], val: SIMD[dtype, width])` Sets data in this tensor slice from a `SIMD`. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the `SIMD` value. * ​\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the tensor slice. * ​element\_alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Indicate the alignment of the pointer stored to memory. This is needed to issue vector store for GPUs with strict alignment requirements. **Args:** * ​index ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): An `IndexList` of size `_rank` to indicate the dimension of the tensor slice to set data in. * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The data to set into this tensor slice. ### `with_layout` `with_layout[new_rank: Int, //, new_static_shape: DimList, new_static_strides: DimList](self, new_runtime_shape: IndexList[new_rank], new_runtime_strides: IndexList[new_rank], offset_ptr: OptionalReg[LegacyUnsafePointer[Scalar[dtype]]] = None) -> ManagedTensorSlice[io_spec, static_spec=static_spec.with_layout[dtype, rank, new_rank](new_static_shape, new_static_strides)]` **Returns:** [`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice) ### `to_layout_tensor` `to_layout_tensor(self) -> LayoutTensor[dtype, static_spec.to_layout[dtype, rank](), MutAnyOrigin]` **Returns:** `LayoutTensor` ### `write_to` `write_to(self, mut writer: T)` Formats this buffer to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `__repr__` `__repr__(self) -> String` Gets the buffer as a string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A compact string representation of the buffer. ### `__str__` `__str__(self) -> String` Gets the buffer as a string. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A compact string of the buffer.
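A hedged sketch combining `dim_size`, `load`, and `store` (assuming a rank-1 `float32` slice with at least four elements; `DynamicTensor` is the alias defined in this module):

```mojo
# Hypothetical: double the first four elements using vectorized access.
fn double_first_four(t: DynamicTensor[DType.float32, 1]):
    if t.dim_size(0) >= 4:
        var v = t.load[4](IndexList[1](0))  # SIMD[DType.float32, 4]
        t.store[4](IndexList[1](0), v * 2)
```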
--- ## VariadicTensors
`@register_passable(trivial)` `struct VariadicTensors[mut: Bool, input: IO, //, dtype: DType, rank: Int, size: Int, io_spec: IOSpec[mut, input], *, static_specs: StaticTuple[StaticTensorSpec[dtype, rank], size]]` A tuple-like container of tensors representing variadic arguments from the graph compiler. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__getitem__` `__getitem__[index: Int](self) -> ManagedTensorSlice[io_spec, static_spec=static_specs.__getitem__[size, Int](index)]` Returns the tensor at the given position in the variadic argument pack. **Parameters:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The index into the variadic tensor arguments. **Returns:** [`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice): The tensor at the specified index. ### `__len__` `__len__(self) -> Int` Returns the number of variadic arguments in the pack. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of variadic arguments.
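Because `__getitem__` takes its index as a compile-time parameter, a variadic pack is typically walked with an unrolled `@parameter for` loop. A hedged sketch (the function name and element types are illustrative):

```mojo
# Hypothetical: sum the first element of each rank-1 input tensor.
fn sum_first_elements[size: Int](
    tensors: InputVariadicTensors[DType.float32, 1, size]
) -> Float32:
    var total: Float32 = 0

    @parameter
    for i in range(size):  # `i` is a compile-time value, as required
        total += tensors[i][0]
    return total
```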
--- ## foreach
`foreach[dtype: DType, rank: Int, //, func: fn[width: Int, element_alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width], *, target: StringSlice[StaticConstantOrigin] = "cpu", simd_width: Int = get_kernel_simd_width[dtype, target](), _trace_name: StringSlice[StaticConstantOrigin] = "mogg.for_each"](tensor: ManagedTensorSlice[io_spec, static_spec=static_spec], ctx: DeviceContextPtr = DeviceContextPtr())` Apply the function `func` to each element of the tensor slice. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements in the tensor slice. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the tensor slice. * ​func (`fn[width: Int, element_alignment: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to apply to each element of the tensor slice. * ​target (`StringSlice`): Indicates the type of the target device (e.g. "cpu", "gpu"). * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD width for the target (usually leave this as its default value). * ​\_trace\_name (`StringSlice`): Name of the executed operation displayed in the trace\_description. **Args:** * ​tensor ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The output tensor slice which receives the return values from `func`. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context (forward this from the custom operation). `foreach[dtype: DType, rank: Int, //, func: fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width], out_func: fn[width: Int](IndexList[rank]) capturing -> None, *, target: StringSlice[StaticConstantOrigin] = "cpu", simd_width: Int = get_kernel_simd_width[dtype, target](), _trace_name: StringSlice[StaticConstantOrigin] = "mogg.for_each"](tensor: ManagedTensorSlice[io_spec, static_spec=static_spec], ctx: DeviceContextPtr = DeviceContextPtr())` Apply the function `func` to each element of the tensor slice. 
**Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements in the tensor slice. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the tensor slice. * ​func (`fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to apply to each element of the tensor slice. * ​out\_func (`fn[width: Int](IndexList[rank]) capturing -> None`): The function to apply on each output element. * ​target (`StringSlice`): Indicates the type of the target device (e.g. "cpu", "gpu"). * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD width for the target (usually leave this as its default value). * ​\_trace\_name (`StringSlice`): Name of the executed operation displayed in the trace\_description. **Args:** * ​tensor ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The input tensor slice whose values are consumed. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context (forward this from the custom operation). `foreach[dtype: DType, rank: Int, //, func: fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width], *, target: StringSlice[StaticConstantOrigin] = "cpu", simd_width: Int = get_kernel_simd_width[dtype, target](), _trace_name: StringSlice[StaticConstantOrigin] = "mogg.for_each"](tensor: ManagedTensorSlice[io_spec, static_spec=static_spec], ctx: DeviceContextPtr = DeviceContextPtr())` Apply the function `func` to each element of the tensor slice. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements in the tensor slice. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the tensor slice. * ​func (`fn[width: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to apply to each element of the tensor slice. * ​target (`StringSlice`): Indicates the type of the target device (e.g. "cpu", "gpu"). 
* ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD width for the target (usually leave this as its default value). * ​\_trace\_name (`StringSlice`): Name of the executed operation displayed in the trace\_description. **Args:** * ​tensor ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The output tensor slice which receives the return values from `func`. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context (forward this from the custom operation).
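Putting the pieces together, a hedged sketch of the common pattern: a captured elementwise function over the input slice, dispatched with `foreach` (loosely modeled on the MAX custom-operation examples; names are illustrative):

```mojo
# Hypothetical custom-op body: add one to every element of `x`,
# writing results into `output`. `foreach` selects CPU or GPU code
# based on `target` and forwards the device context.
fn add_one_kernel[target: StaticString](
    output: OutputTensor,
    x: InputTensor[dtype = output.dtype, rank = output.rank],
    ctx: DeviceContextPtr,
) raises:
    @parameter
    @always_inline
    fn add_one[width: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, width]:
        return x.load[width](idx) + 1

    foreach[add_one, target=target](output, ctx)
```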
--- ## managed_tensor_slice
Implements the `ManagedTensorSlice` type - a view of a tensor that doesn't own the underlying data. This type is used to build custom graph operations. ## `comptime` values ### `DynamicTensor` `comptime DynamicTensor[dtype: DType, rank: Int] = ManagedTensorSlice[IOUnknown, static_spec=StaticTensorSpec.create_unknown[dtype, rank]()]` #### Parameters * ​dtype (`DType`): * ​rank ([`Int`](/stdlib/builtin/int/Int)): ### `InputTensor` `comptime InputTensor = ManagedTensorSlice[Input, static_spec=?]` ### `InputVariadicTensors` `comptime InputVariadicTensors = VariadicTensors[?, ?, ?, Input, static_specs=?]` ### `OutputTensor` `comptime OutputTensor = ManagedTensorSlice[Output, static_spec=?]` ### `OutputVariadicTensors` `comptime OutputVariadicTensors = VariadicTensors[?, ?, ?, Output, static_specs=?]` ## Structs * [​`ManagedTensorSlice`](./ManagedTensorSlice): A view of a tensor that does not own the underlying allocated pointer. When the object lifetime ends it does not free the underlying pointer. Conversely, if a `ManagedTensorSlice` is created, it will not extend the life of the underlying pointer. * [​`VariadicTensors`](./VariadicTensors): A tuple-like container of tensors representing variadic arguments from the graph compiler. ## Functions * [​`foreach`](./foreach): Apply the function `func` to each element of the tensor slice. * [​`rebuild_mix_precision_static_tensor_specs_with_input_lambda`](./rebuild_mix_precision_static_tensor_specs_with_input_lambda): * [​`rebuild_static_tensor_specs_with_compute_output_lambda`](./rebuild_static_tensor_specs_with_compute_output_lambda): * [​`rebuild_static_tensor_specs_with_input_lambda`](./rebuild_static_tensor_specs_with_input_lambda): * [​`rebuild_static_tensor_specs_with_output_lambda`](./rebuild_static_tensor_specs_with_output_lambda): * [​`trace_slice_arg`](./trace_slice_arg): Helper to stringify the type and shape of a kernel argument for tracing.
--- ## rebuild_mix_precision_static_tensor_specs_with_input_lambda
`rebuild_mix_precision_static_tensor_specs_with_input_lambda[func_type: AnyTrivialRegType, //, src_dtype: DType, dst_dtype: DType, rank: Int](spec: StaticTensorSpec[src_dtype, rank], in_lambda: func_type) -> StaticTensorSpec[dst_dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## rebuild_static_tensor_specs_with_compute_output_lambda
`rebuild_static_tensor_specs_with_compute_output_lambda[func_type: AnyTrivialRegType, //, dtype: DType, rank: Int](spec: StaticTensorSpec[dtype, rank], out_compute_lambda: func_type) -> StaticTensorSpec[dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## rebuild_static_tensor_specs_with_input_lambda
`rebuild_static_tensor_specs_with_input_lambda[func_type: AnyTrivialRegType, //, dtype: DType, rank: Int](spec: StaticTensorSpec[dtype, rank], in_lambda: func_type) -> StaticTensorSpec[dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## rebuild_static_tensor_specs_with_output_lambda
`rebuild_static_tensor_specs_with_output_lambda[func_type: AnyTrivialRegType, //, dtype: DType, rank: Int](spec: StaticTensorSpec[dtype, rank], out_lambda: func_type) -> StaticTensorSpec[dtype, rank]` **Returns:** [`StaticTensorSpec`](/mojo/compiler_internal/directives/StaticTensorSpec)
--- ## trace_slice_arg
`trace_slice_arg(name: String, buf: ManagedTensorSlice[io_spec, static_spec=static_spec]) -> String` Helper to stringify the type and shape of a kernel argument for tracing. **Args:** * ​name ([`String`](/mojo/stdlib/collections/string/string/String)): The name of the argument. * ​buf ([`ManagedTensorSlice`](/mojo/tensor/managed_tensor_slice/ManagedTensorSlice)): The NDBuffer to trace. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the buffer with its shape and data type.
--- ## ElementwiseBinaryComparisonOp
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `elementwise` `static elementwise[dtype: DType, width: Int](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[DType.bool, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## ElementwiseBinaryOp
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `elementwise` `static elementwise[dtype: DType, width: Int](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## ElementwiseUnaryMixedOp
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `elementwise` `static elementwise[dtype: DType, out_dtype: DType, width: Int](x: SIMD[dtype, width]) -> SIMD[out_dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## ElementwiseUnaryOp
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `elementwise` `static elementwise[dtype: DType, width: Int](x: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
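A hedged sketch of a conforming type (the `Negate` struct is hypothetical):

```mojo
# Hypothetical ElementwiseUnaryOp implementation: pointwise negation.
struct Negate(ElementwiseUnaryOp):
    @staticmethod
    fn elementwise[dtype: DType, width: Int](
        x: SIMD[dtype, width]
    ) -> SIMD[dtype, width]:
        return -x
```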
--- ## operation_traits
## Traits * [​`ElementwiseBinaryComparisonOp`](./ElementwiseBinaryComparisonOp): * [​`ElementwiseBinaryOp`](./ElementwiseBinaryOp): * [​`ElementwiseUnaryMixedOp`](./ElementwiseUnaryMixedOp): * [​`ElementwiseUnaryOp`](./ElementwiseUnaryOp):
--- ## RuntimeTensorSpec
`@register_passable(trivial)` `struct RuntimeTensorSpec[dtype: DType, rank: Int]` ## Fields * ​shape (`IndexList[rank]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__getitem__` `__getitem__(self, idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `bytecount` `bytecount(self) -> Int` Gets the total byte count. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total byte count.
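Assuming the compiler-generated memberwise initializer for the `shape` field, a hedged sketch:

```mojo
# Hypothetical: a 2x3 float32 spec. bytecount() is expected to be the
# element count times the dtype's size: 2 * 3 * 4 = 24 bytes.
var spec = RuntimeTensorSpec[DType.float32, 2](IndexList[2](2, 3))
var rows = spec[0]  # __getitem__ indexes the shape
var total_bytes = spec.bytecount()
```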
--- ## tensor_spec
You can import these APIs from the `max.tensor` package. For example: ```mojo from max.tensor import RuntimeTensorSpec ``` ## Structs * [​`RuntimeTensorSpec`](./RuntimeTensorSpec):
--- ## transitional
Utilities for transitional period during NDBuffer deprecation. ## Functions * [​`managed_tensor_slice_to_ndbuffer`](./managed_tensor_slice_to_ndbuffer):
--- ## managed_tensor_slice_to_ndbuffer
`managed_tensor_slice_to_ndbuffer[spec: StaticTensorSpec[dtype, rank], //](tensor: ManagedTensorSlice[io_spec, static_spec=spec]) -> NDBuffer[dtype, rank, MutAnyOrigin, spec.shape, spec.strides, address_space=spec.address_space]` **Returns:** `NDBuffer`
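A hedged usage sketch for this transitional bridge (`legacy_kernel` is a hypothetical NDBuffer-based function):

```mojo
# Hypothetical: adapt a ManagedTensorSlice for a kernel that still
# takes NDBuffer, during the deprecation window.
fn call_legacy_kernel(t: InputTensor) raises:
    var buf = managed_tensor_slice_to_ndbuffer(t)
    legacy_kernel(buf)  # hypothetical NDBuffer-based kernel
```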
--- ## kv_cache
Contains implementations for several types of key-value caches. [KV caches](/glossary/ai/kv-cache) are used in transformer models to store key-value tensors output from self-attention layers. These APIs are used in the higher-level functions in the [`nn`](/mojo/kernels/nn) package. ## Modules * [​`types`](./types/): This module contains the types for the key-value cache APIs.
--- ## ContinuousBatchingKVCache
`@register_passable(trivial)` `struct ContinuousBatchingKVCache[dtype_: DType, kv_params_: KVCacheStaticParams]` Wrapper for the ContinuousKVCache of a given layer in the transformer model. This abstracts the Pointer indirection for accessing the ContinuousKVCache for a given batch entry. THIS IS THE TYPE THAT IS PASSED TO KV PROJECTION AND FLASH ATTENTION KERNELS. ## Parameters * ​dtype\_ ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the kv-cache. * ​kv\_params\_ ([`KVCacheStaticParams`](/mojo/kernels/kv_cache/types/KVCacheStaticParams)): The kv-cache static parameters. ## Fields * ​blocks (`ContinuousBatchingKVCache[dtype_, kv_params_].blocks_type`): * ​cache\_lengths (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​lookup\_table (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`KVCacheT`](/mojo/kernels/kv_cache/types/KVCacheT), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCache[dtype_, kv_params_].blocks_shape)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, -1, Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params), Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params))` ### `blocks_type` `comptime blocks_type = LayoutTensor[ContinuousBatchingKVCache[dtype_, 
kv_params_].dtype, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_layout, MutAnyOrigin]` ### `device_type` `comptime device_type = ContinuousBatchingKVCache[dtype_, kv_params_]` ### `dtype` `comptime dtype = dtype_` ### `kv_params` `comptime kv_params = kv_params_` ### `page_size_` `comptime page_size_ = 0` ## Methods ### `__init__` `__init__(blocks: LayoutTensor[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, ContinuousBatchingKVCache[dtype_, kv_params_].blocks_layout, MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** `String` ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** `String` ### `max_tile_size` `static max_tile_size() -> Int` Returns the maximum tile size for the KVCache. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `cache_lengths_nd` `cache_lengths_nd(self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `cache_length` `cache_length(self, batch_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `load` `load[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `store` `store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, size])` ### `empty_cache` `empty_cache(self) -> Bool` Returns true if the cache\_lengths for all requests is 0, false otherwise. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `max_prompt_length` `max_prompt_length(self) -> UInt32` Returns the maximum sequence length across all batches of the current request. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `max_context_length` `max_context_length(self) -> UInt32` Returns the maximum cache length used across all batches of the current request. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `row_idx` `row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[BN: Int, swizzle_mode: TensorMapSwizzle](self, ctx: DeviceContext) -> TMATensorTile[ContinuousBatchingKVCache[dtype_, kv_params_].dtype, _split_last_layout[ContinuousBatchingKVCache[dtype_, kv_params_].dtype](IndexList[3, DType.int64](BN, 1, Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params), Tuple[]()), swizzle_mode, True), _ragged_desc_layout[ContinuousBatchingKVCache[dtype_, kv_params_].dtype](IndexList[3, DType.int64](BN, 1, Int(ContinuousBatchingKVCache[dtype_, kv_params_].kv_params), Tuple[]()), swizzle_mode)]` Creates a TMA tile for this KV cache. **Returns:** `TMATensorTile` ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> LegacyUnsafePointer[Scalar[ContinuousBatchingKVCache[dtype_, kv_params_].dtype]]` **Returns:** `LegacyUnsafePointer`
--- ## ContinuousBatchingKVCacheCollection
`struct ContinuousBatchingKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams]` This is a "view" of the cache for the given sequences in the batch. This object does not own the underlying buffers in k\_cache and v\_cache, it's borrowing them from the BlockWrappers in our KVCacheManager. It does own the Pointer\[NDBuffer\[dtype, 3]] and valid\_lengths buffer ## Parameters * ​dtype\_ ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the kv-cache. * ​kv\_params\_ ([`KVCacheStaticParams`](/mojo/kernels/kv_cache/types/KVCacheStaticParams)): The kv-cache static parameters. ## Fields * ​cache\_lengths (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​lookup\_table (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​blocks (`ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_type`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): * ​kv\_cache\_dynamic\_shape (`IndexList[4]`): * ​kv\_cache\_dynamic\_strides (`IndexList[4]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`KVCollectionT`](/mojo/kernels/kv_cache/types/KVCollectionT), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout.row_major(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_shape)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, -1, -1, -1, Int(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params), Int(ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params))` ### 
`blocks_type` `comptime blocks_type = LayoutTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].blocks_layout, MutAnyOrigin]` ### `CacheType` `comptime CacheType = ContinuousBatchingKVCache[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, ContinuousBatchingKVCacheCollection[dtype_, kv_params_].kv_params]` ### `dtype` `comptime dtype = dtype_` ### `kv_params` `comptime kv_params = kv_params_` ### `name_str` `comptime name_str = "continuous_batching"` ## Methods ### `__init__` `__init__(out self, blocks: LayoutTensor[ContinuousBatchingKVCacheCollection[dtype_, kv_params_].dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32)` ### `get_key_cache` `get_key_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType` **Returns:** `ContinuousBatchingKVCacheCollection` ### `get_value_cache` `get_value_cache(self, layer_idx: Int) -> ContinuousBatchingKVCacheCollection[dtype_, kv_params_].CacheType` **Returns:** `ContinuousBatchingKVCacheCollection` ### `cache_length` `cache_length(self, bs_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## KVCacheStaticParams
`@register_passable(trivial)` `struct KVCacheStaticParams` ## Fields * ​num\_heads (`UInt`): * ​head\_size (`UInt`): * ​is\_mla (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(num_heads: UInt, head_size: UInt, is_mla: Bool = False) -> Self` Initializes `KVCacheStaticParams`. **Args:** * ​num\_heads (`UInt`): Number of attention heads. * ​head\_size (`UInt`): Size of each attention head. * ​is\_mla (`Bool`): Whether to use Multi-head Latent Attention (MLA) mode. If True, only the key cache is stored; if False, both key and value caches are stored. Defaults to False. ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
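A hedged construction sketch (head counts are illustrative):

```mojo
# Hypothetical configurations: 8 KV heads of size 128.
var params = KVCacheStaticParams(num_heads=8, head_size=128)
var mla_params = KVCacheStaticParams(num_heads=8, head_size=128, is_mla=True)
# Equatable: configurations differing only in is_mla compare unequal.
var differ = params != mla_params
```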
--- ## KVCacheT
Trait for different KVCache types and implementations. Represents a single (key or value) cache. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. 
### `dtype` `comptime dtype` ### `kv_params` `comptime kv_params` ### `page_size_` `comptime page_size_` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `cache_lengths_nd` `cache_lengths_nd(self: _Self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` Returns the cache lengths as a NDBuffer. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `cache_length` `cache_length(self: _Self, batch_idx: Int) -> Int` Returns the length of the cache for a given batch index. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `load` `load[width: Int](self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[_Self.dtype, width]` Loads an element from the given index. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `store` `store(self: _Self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[_Self.dtype, size])` Stores an element at the given index. ### `empty_cache` `empty_cache(self: _Self) -> Bool` Returns true if the cache\_lengths for all requests is 0, false otherwise. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `max_prompt_length` `max_prompt_length(self: _Self) -> UInt32` Returns the maximum sequence length across all batches of the current request. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `max_context_length` `max_context_length(self: _Self) -> UInt32` Returns the maximum cache length used across all batches of the current request. 
**Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self: _Self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> LegacyUnsafePointer[Scalar[_Self.dtype]]` Returns a pointer to the KVCache block at the given index. Paged KVCache implementations must have a block\_size which is a multiple of, and greater than, the layout's first dimension. **Returns:** `LegacyUnsafePointer` ### `max_tile_size` `static max_tile_size() -> Int` Returns the maximum tile size for the KVCache. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `row_idx` `row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[BN: Int, swizzle_mode: TensorMapSwizzle](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, _split_last_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, Int(_Self.kv_params), Tuple[]()), swizzle_mode, True), _ragged_desc_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, Int(_Self.kv_params), Tuple[]()), swizzle_mode)]` Creates a TMA tile for this KV cache. **Returns:** `TMATensorTile` ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** `String`: The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name.
For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** `String`: The device type's name. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
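For illustration, code can be written generically against this trait using only the methods above. The following is a minimal sketch assuming a conforming cache value; the function name and the choice of head and head-dim indices are hypothetical, not part of the API:

```mojo
# Hypothetical sketch: a function generic over any KVCacheT
# implementation. It sums the first head-dim element of every cached
# token for one batch entry, using only trait methods documented above.
fn sum_cached_tokens[CacheT: KVCacheT](cache: CacheT, batch_idx: Int) -> Float64:
    var total: Float64 = 0.0
    if cache.empty_cache():
        return total
    for tok_idx in range(cache.cache_length(batch_idx)):
        # Scalar load (width = 1) at head 0, head_dim 0.
        total += cache.load[1](batch_idx, 0, tok_idx, 0).cast[DType.float64]()
    return total
```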
--- ## KVCollectionT
Trait for a pair of caches (keys and values). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
### `CacheType` `comptime CacheType` ### `dtype` `comptime dtype` ### `kv_params` `comptime kv_params` ### `name_str` `comptime name_str` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `get_key_cache` `get_key_cache(self: _Self, layer_idx: Int) -> _Self.CacheType` **Returns:** `_Self.CacheType` ### `get_value_cache` `get_value_cache(self: _Self, layer_idx: Int) -> _Self.CacheType` **Returns:** `_Self.CacheType` ### `cache_length` `cache_length(self: _Self, bs_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
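A collection pairs one key cache and one value cache per layer. A minimal sketch of generic usage follows; the function and the externally supplied `num_layers` are hypothetical, since the trait itself does not expose a layer count:

```mojo
# Hypothetical sketch: walking a KVCollectionT layer by layer.
fn inspect_collection[CollectionT: KVCollectionT](
    kv: CollectionT, num_layers: Int, batch_size: Int
):
    for bs_idx in range(batch_size):
        print("cache length for request", bs_idx, "=", kv.cache_length(bs_idx))
    for layer_idx in range(num_layers):
        # Each accessor returns a value of the associated CacheType.
        var k_cache = kv.get_key_cache(layer_idx)
        var v_cache = kv.get_value_cache(layer_idx)
        _ = k_cache
        _ = v_cache
```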
--- ## PagedKVCache
`@register_passable(trivial)` `struct PagedKVCache[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int]` The PagedKVCache is a wrapper around the KVCache blocks for a given layer. It is used to access the KVCache blocks for PagedAttention. ## Parameters * ​dtype\_ ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the kv-cache. * ​kv\_params\_ ([`KVCacheStaticParams`](/mojo/kernels/kv_cache/types/KVCacheStaticParams)): The kv-cache static parameters. * ​page\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the page. ## Fields * ​blocks (`PagedKVCache[dtype_, kv_params_, page_size].blocks_type`): * ​cache\_lengths (`LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]`): * ​lookup\_table (`LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin]`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`KVCacheT`](/mojo/kernels/kv_cache/types/KVCacheT), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout.row_major(PagedKVCache[dtype_, kv_params_, page_size].blocks_shape)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, page_size, Int(PagedKVCache[dtype_, kv_params_, page_size].kv_params), Int(PagedKVCache[dtype_, kv_params_, page_size].kv_params))` ### `blocks_type` `comptime blocks_type = LayoutTensor[PagedKVCache[dtype_, kv_params_, page_size].dtype, PagedKVCache[dtype_, kv_params_, 
page_size].blocks_layout, MutAnyOrigin]` ### `device_type` `comptime device_type = PagedKVCache[dtype_, kv_params_, page_size]` ### `dtype` `comptime dtype = dtype_` ### `kv_params` `comptime kv_params = kv_params_` ### `page_size_` `comptime page_size_ = page_size` ## Methods ### `__init__` `__init__(blocks: LayoutTensor[PagedKVCache[dtype_, kv_params_, page_size].dtype, PagedKVCache[dtype_, kv_params_, page_size].blocks_layout, MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** `String` ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** `String` ### `max_tile_size` `static max_tile_size() -> Int` Returns the maximum tile size for the KVCache. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `cache_lengths_nd` `cache_lengths_nd(self) -> LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `cache_length` `cache_length(self, batch_idx: Int) -> Int` Returns the length of the cache for a given batch index. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `row_idx` `row_idx(self, batch_idx: UInt32, tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. 
**Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[BN: Int, swizzle_mode: TensorMapSwizzle](self, ctx: DeviceContext) -> TMATensorTile[PagedKVCache[dtype_, kv_params_, page_size].dtype, _split_last_layout[PagedKVCache[dtype_, kv_params_, page_size].dtype](IndexList[3, DType.int64](BN, 1, Int(PagedKVCache[dtype_, kv_params_, page_size].kv_params), Tuple[]()), swizzle_mode, True), _ragged_desc_layout[PagedKVCache[dtype_, kv_params_, page_size].dtype](IndexList[3, DType.int64](BN, 1, Int(PagedKVCache[dtype_, kv_params_, page_size].kv_params), Tuple[]()), swizzle_mode)]` Creates a TMA tile for this KV cache. **Returns:** `TMATensorTile` ### `load` `load[width: Int](self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int) -> SIMD[PagedKVCache[dtype_, kv_params_, page_size].dtype, width]` Loads an element from the given index. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `store` `store(self, bs: Int, head_idx: Int, tok_idx: Int, head_dim_idx: Int, val: SIMD[PagedKVCache[dtype_, kv_params_, page_size].dtype, size])` Stores an element at the given index. ### `empty_cache` `empty_cache(self) -> Bool` Returns true if the cache\_lengths for all requests is 0, false otherwise. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `max_prompt_length` `max_prompt_length(self) -> UInt32` Returns the maximum sequence length across all batches of the current request. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `max_context_length` `max_context_length(self) -> UInt32` Returns the maximum cache length used across all batches of the current request. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self, batch_idx: Int, start_tok_idx: Int, head_idx: Int, head_dim_idx: Int = 0) -> LegacyUnsafePointer[Scalar[PagedKVCache[dtype_, kv_params_, page_size].dtype]]` **Returns:** `LegacyUnsafePointer`
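As a usage illustration, `block_paged_ptr` hands back a raw pointer into the page containing a given token range. The sketch below is hypothetical (the tile size and index choices are assumptions) and only exercises signatures listed above:

```mojo
# Hypothetical sketch: reading a tile of cache entries through
# block_paged_ptr. The tile_size value is an assumption; it must
# satisfy the page-size constraint described in the KVCacheT trait.
fn read_tile[
    dtype: DType, kv_params: KVCacheStaticParams, page_size: Int
](
    cache: PagedKVCache[dtype, kv_params, page_size],
    batch_idx: Int,
    start_tok_idx: Int,
):
    comptime tile_size = 16
    var ptr = cache.block_paged_ptr[tile_size](batch_idx, start_tok_idx, 0)
    # ptr addresses tile_size consecutive tokens within a single page
    # for head 0, starting at head_dim_idx = 0 (the default).
    _ = ptr
```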
--- ## PagedKVCacheCollection
`struct PagedKVCacheCollection[dtype_: DType, kv_params_: KVCacheStaticParams, page_size: Int]` ## Fields * ​blocks (`PagedKVCacheCollection[dtype_, kv_params_, page_size].blocks_type`): * ​cache\_lengths (`PagedKVCacheCollection[dtype_, kv_params_, page_size].cache_lengths_type`): * ​lookup\_table (`PagedKVCacheCollection[dtype_, kv_params_, page_size].lookup_table_type`): * ​max\_seq\_length (`UInt32`): * ​max\_cache\_length (`UInt32`): * ​kv\_cache\_dynamic\_shape (`IndexList[4]`): * ​kv\_cache\_dynamic\_strides (`IndexList[4]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`KVCollectionT`](/mojo/kernels/kv_cache/types/KVCollectionT), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocks_layout` `comptime blocks_layout = Layout.row_major(PagedKVCacheCollection[dtype_, kv_params_, page_size].blocks_shape)` ### `blocks_shape` `comptime blocks_shape = IntTuple(-1, 2 if (not kv_params_.is_mla._mlir_value) else 1, -1, page_size, Int(PagedKVCacheCollection[dtype_, kv_params_, page_size].kv_params), Int(PagedKVCacheCollection[dtype_, kv_params_, page_size].kv_params))` ### `blocks_type` `comptime blocks_type = LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size].dtype, PagedKVCacheCollection[dtype_, kv_params_, page_size].blocks_layout, MutAnyOrigin]` ### `cache_lengths_type` `comptime cache_lengths_type = LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin]` ### `CacheType` `comptime CacheType = PagedKVCache[PagedKVCacheCollection[dtype_, kv_params_, page_size].dtype, 
PagedKVCacheCollection[dtype_, kv_params_, page_size].kv_params, page_size]` ### `dtype` `comptime dtype = dtype_` ### `kv_params` `comptime kv_params = kv_params_` ### `lookup_table_type` `comptime lookup_table_type = LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin]` ### `name_str` `comptime name_str = "paged"` ## Methods ### `__init__` `__init__(out self, blocks: LayoutTensor[PagedKVCacheCollection[dtype_, kv_params_, page_size].dtype, Layout.row_major[6](), MutAnyOrigin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), ImmutAnyOrigin], lookup_table: LayoutTensor[DType.uint32, Layout.row_major[2](), ImmutAnyOrigin], max_seq_length: UInt32, max_cache_length: UInt32)` ### `get_key_cache` `get_key_cache(self, layer_idx: Int) -> PagedKVCacheCollection[dtype_, kv_params_, page_size].CacheType` Returns the key cache for the given layer. **Returns:** [`PagedKVCache`](./PagedKVCache) ### `get_value_cache` `get_value_cache(self, layer_idx: Int) -> PagedKVCacheCollection[dtype_, kv_params_, page_size].CacheType` Returns the value cache for the given layer. **Returns:** [`PagedKVCache`](./PagedKVCache) ### `cache_length` `cache_length(self, bs_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## types
This module contains the types for the key-value cache APIs. The module includes structs implementing several different types of [KV caches](/glossary/ai/kv-cache). It also defines two traits that specify the roles of the different structs: * `KVCacheT`: Defines the interface for a single (key or value) cache. * `KVCollectionT`: Defines the interface for a pair of caches (keys and values). ## Structs * [​`ContinuousBatchingKVCache`](./ContinuousBatchingKVCache): Wrapper for the ContinuousKVCache of a given layer in the transformer model. * [​`ContinuousBatchingKVCacheCollection`](./ContinuousBatchingKVCacheCollection): This is a "view" of the cache for the given sequences in the batch. * [​`KVCacheStaticParams`](./KVCacheStaticParams): * [​`PagedKVCache`](./PagedKVCache): The PagedKVCache is a wrapper around the KVCache blocks for a given layer. It is used to access the KVCache blocks for PagedAttention. * [​`PagedKVCacheCollection`](./PagedKVCacheCollection): ## Traits * [​`KVCacheT`](./KVCacheT): Trait for different KVCache types and implementations. * [​`KVCollectionT`](./KVCollectionT): Trait for a pair of caches (keys and values).
--- ## CopyPolicy
The CopyPolicy trait defines requirements needed for a tensor to be copied. These requirements check the compatibility of the source and destination tensors. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `verify_source_tensor` `static verify_source_tensor(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` A static function that verifies that the source tensor is compatible with the copy operation. If the tensor is not valid, compilation will fail. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor that will be copied from. ### `verify_destination_tensor` `static verify_destination_tensor(dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` A static function that verifies that the destination tensor is compatible with the copy operation. If the tensor is not valid, compilation will fail. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor that will be copied to. ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** `String`: The host type's name.
### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** `String`: The device type's name. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## copy
## Traits * [​`CopyPolicy`](./CopyPolicy): The CopyPolicy trait defines requirements needed for a tensor to be copied.
--- ## Element
`struct Element[dtype: DType, layout: Layout, /, index_type: DType = _get_index_type(layout)]` A wrapper around SIMD types that provides layout-driven vectorized operations. The `Element` struct extends SIMD types with layout-aware load and store operations, enabling efficient vectorized access to multi-dimensional data. It maps between logical tensor coordinates and physical memory locations according to the specified layout. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout describing how elements are organized. * ​index\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type of the index pointing to each element. ## Fields * ​element\_data (`Element[dtype, layout, index_type].element_data_type`): The actual SIMD data stored in this element. This field contains the vectorized data values that can be processed efficiently using SIMD operations. * ​runtime\_layout (`RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type]`): The runtime layout information for memory access patterns. This field stores the layout information needed to map between logical tensor coordinates and physical memory locations, supporting both compile-time and runtime-determined access patterns. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `element_data_type` `comptime element_data_type = SIMD[dtype, layout.size()]` The SIMD type used to store and process the element data. This type alias defines a SIMD vector with the specified data type and size matching the layout's total element count, enabling efficient vectorized operations. 
## Methods ### `__init__` `__init__(out self, element_data: SIMD[dtype, layout.size()])` Initializes an Element with the given SIMD data. **Args:** * ​element\_data ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The SIMD data to initialize the element with. `__init__(out self, element_data: SIMD[dtype, layout.size()], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type])` Initializes an Element with the given SIMD data and runtime layout. **Args:** * ​element\_data ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The SIMD data to initialize the element with. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. ### `load` `static load(ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type] = RuntimeLayout[layout, DType.int32, index_type]()) -> Self` Loads data from memory according to the specified layout. This method loads data from memory using the layout information to determine the memory access pattern. It supports both rank-1 and rank-2 layouts with various stride patterns, optimizing for contiguous memory access when possible. **Args:** * ​ptr (`LegacyUnsafePointer`): Pointer to the memory location to load from. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. **Returns:** `Self`: A new `Element` containing the loaded data. ### `masked_load` `static masked_load(ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type] = RuntimeLayout[layout, DType.int32, index_type]()) -> Self` Loads data from memory with masking for partial loads. 
This method loads data from memory using the layout information, but also handles cases where the runtime dimensions are smaller than the static layout dimensions. It ensures that only valid memory locations are accessed. **Args:** * ​ptr (`LegacyUnsafePointer`): Pointer to the memory location to load from. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. **Returns:** `Self`: A new `Element` containing the loaded data, with zeros in positions beyond the runtime dimensions. ### `store` `store(self, ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin])` Stores element data to memory according to the specified layout. This method performs a layout-aware store operation, writing data to memory following the access patterns defined by the layout. It optimizes memory writes based on the layout's stride patterns to maximize performance. The method handles different memory layout patterns: * For rank-1 tensors with contiguous memory (stride=1), it uses vectorized stores * For rank-2 tensors with contiguous rows or columns, it uses optimized slice-based stores * For non-contiguous memory layouts, it performs element-by-element stores Unlike `masked_store()`, this method assumes the full static dimensions will be written and does not perform runtime dimension boundary checking. Note: This method is constrained to layouts with rank <= 2. For higher-rank tensors, consider decomposing the operation. **Args:** * ​ptr (`LegacyUnsafePointer`): Mutable pointer to the memory location where data will be stored. ### `masked_store` `masked_store(self, ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, origin=origin])` Stores element data to memory with masking for partial stores. This method performs a layout-aware store operation with boundary checking. 
It ensures that only valid memory locations are written to when the runtime dimensions are smaller than the static layout dimensions, preventing out-of-bounds memory access. The method optimizes for different memory layouts: * For contiguous memory (stride=1), it uses vectorized stores when possible * For non-contiguous memory, it performs element-by-element stores * For all patterns, it respects runtime dimension bounds Note: This method is constrained to layouts with rank <= 2. For higher-rank tensors, consider decomposing the operation. **Args:** * ​ptr (`LegacyUnsafePointer`): Pointer to the memory location where data will be stored. ### `__str__` `__str__(self) -> String` Returns a string representation of the element. **Returns:** `String`: A string representation of the element's data. ### `write_to` `write_to(self, mut writer: T)` Writes the element to the specified writer. **Args:** * ​writer (`T`): The writer to output the element representation to.
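Putting `load` and `store` together, a round trip through `Element` looks roughly like the sketch below. The import paths and the use of a rank-1 row-major layout are assumptions; only the `load`/`store` signatures come from this page:

```mojo
# Hypothetical sketch: doubling four contiguous floats via Element.
from layout import Layout
from layout.element import Element  # module path is an assumption

fn double_in_place(ptr: LegacyUnsafePointer[Float32]):
    comptime layout = Layout.row_major(4)
    # Layout-aware vectorized load of 4 elements.
    var elem = Element[DType.float32, layout].load(ptr)
    elem.element_data = elem.element_data * 2
    # Write the updated SIMD vector back to the same location.
    elem.store(ptr)
```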
--- ## MemoryElement
`struct MemoryElement[mut: Bool, //, dtype: DType, layout: Layout, origin: Origin[mut], /, address_space: AddressSpace, *, index_type: DType = _get_index_type(layout, address_space)]` Represents data in memory organized according to a specific layout. The `MemoryElement` struct provides a high-level interface for accessing data in memory with a specific layout. It encapsulates a pointer to the memory location and the runtime layout information needed to access the data correctly. This abstraction enables efficient memory operations that respect the underlying memory organization, supporting vectorized loads and stores while handling different memory layouts transparently. ## Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the memory element is mutable. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout describing how elements are organized. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the memory element. * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The memory address space where the data is located. * ​index\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type of the index pointing to each memory element. ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin]`): Pointer to the memory location where the data is stored. This pointer provides access to the underlying memory with the specified address space and alignment requirements. It points to the first element of the data structure in memory. * ​runtime\_layout (`RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type]`): Runtime layout information used for memory access calculations. 
This field stores the runtime layout information needed to compute memory offsets for accessing elements according to the specified layout pattern. It handles both compile-time known dimensions and runtime-determined dimensions. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin], runtime_layout: RuntimeLayout[layout, element_type=DType.int32, linear_idx_type=index_type])` Initializes a `MemoryElement` with the given pointer and runtime layout. **Args:** * ​ptr (`LegacyUnsafePointer`): Pointer to the memory location of the element. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout to use for memory access. ### `load` `load(self, out result: Element[dtype, layout, index_type])` Loads data from memory according to the specified layout. This method performs a layout-aware load operation, reading data from memory following the access patterns defined by the layout. It optimizes memory reads based on the layout's stride patterns to maximize performance. The method leverages the underlying `Element.load` implementation which handles different memory layout patterns including contiguous and strided access. **Returns:** [`Element`](/mojo/kernels/layout/element/Element): An `Element` containing the loaded data organized according to the layout. ### `store` `store(self: MemoryElement[dtype, layout, mut_origin, address_space, index_type=index_type], src: Element[dtype, layout, index_type])` Stores element data to the memory location of this MemoryElement. This method performs a layout-aware store operation, writing data to memory following the access patterns defined by the layout. 
It optimizes memory writes based on the layout's stride patterns to maximize performance. The method delegates to the `Element.store` implementation which handles different memory layout patterns including vectorized stores for contiguous memory and element-by-element stores for non-contiguous layouts. **Args:** * ​src ([`Element`](/mojo/kernels/layout/element/Element)): The `Element` containing the data to store. ### `transfer` `transfer(self: MemoryElement[dtype, layout, mut_origin, address_space, index_type=index_type], src: MemoryElement[dtype, layout, origin, address_space, index_type=index_type])` Transfers data from another `MemoryElement` to this one. This method efficiently transfers data between memory locations with potentially different layouts and data types. It performs the following operations: 1. Loads data from the source `MemoryElement` using its layout 2. Converts the data to the destination data type if necessary 3. Stores the converted data to the destination memory location using its layout This provides a high-performance way to copy and convert data between different memory representations while respecting both source and destination memory layouts. **Args:** * ​src ([`MemoryElement`](/mojo/kernels/layout/element/MemoryElement)): The source `MemoryElement` to transfer data from.
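The three steps of `transfer` (load, convert, store) collapse into one call; everything about how the two `MemoryElement` values were constructed is assumed here, and the function name is illustrative:

```mojo
# Hypothetical sketch: transfer() is a layout-aware load from src,
# a dtype conversion when needed (per the description above), then a
# store into dst.
fn copy_between(dst: MemoryElement, src: MemoryElement):
    dst.transfer(src)
```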
--- ## element
Provides element-based access to memory using layout-driven vectorization. This module implements efficient memory access patterns for multi-dimensional data using the layout system. It provides abstractions for loading and storing data with specific memory layouts, enabling vectorized operations that respect the underlying memory organization. Key components: * `Element`: A wrapper around SIMD types that provides layout-driven vectorized operations * `MemoryElement`: Represents data in memory organized according to a specific layout These components enable efficient tensor operations by ensuring memory accesses follow optimal patterns defined by the layout system. ## Structs * [​`Element`](./Element): A wrapper around SIMD types that provides layout-driven vectorized operations. * [​`MemoryElement`](./MemoryElement): Represents data in memory organized according to a specific layout.
--- ## layout
Provides layout and layout tensor types, which abstract memory layout for multidimensional data. * The [`Layout`](/mojo/kernels/layout/layout/Layout) type represents a mapping between a set of logical coordinates and a linear index. It can be used, for example, to map logical tensor coordinates to a memory address, or to map GPU threads to tiles of data. * The [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) type is a high-performance tensor with explicit memory layout via a `Layout`. ## Modules * [​`copy`](./copy/): * [​`element`](./element/): Provides element-based access to memory using layout-driven vectorization. * [​`int_tuple`](./int_tuple/): Hierarchical integer tuple data structures for high-performance tensor operations. * [​`layout`](./layout/): Provides a high-performance tensor layout system for memory mapping and indexing. * [​`layout_tensor`](./layout_tensor/): Provides the `LayoutTensor` type for representing multidimensional data. * [​`math`](./math/): Implements math methods that work on layout tensors. * [​`runtime_layout`](./runtime_layout/): Provides the `RuntimeLayout` type and functions for working with it. You can use `RuntimeLayout` to define a layout where the dimensions are not known at compile time. * [​`runtime_tuple`](./runtime_tuple/): Provides the `RuntimeTuple` data structure and related utility functions for handling tuple-like data with both compile-time and runtime elements. `RuntimeTuple` is designed for high-performance tensor operations, supporting efficient manipulation of multi-dimensional data structures like shapes, indices, and coordinates. * [​`swizzle`](./swizzle/): Defines swizzle layouts for optimizing memory access patterns. * [​`tensor_core`](./tensor_core/): Tensor core module for high-performance matrix operations. * [​`tensor_core_async`](./tensor_core_async/): Asynchronous tensor core module. * [​`tma_async`](./tma_async/): Tensor Memory Accelerator (TMA) asynchronous operations module.
--- ## IntArray
`@register_passable` `struct IntArray` A memory-efficient, register-passable array of integers. `IntArray` provides a low-level implementation of a dynamically-sized integer array with direct memory management. This struct serves as the underlying storage mechanism for `IntTuple` and related data structures, optimized for high-performance tensor operations. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(size: Int = 0) -> Self` Initialize a new owned `IntArray` with the specified size. **Args:** * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of integers to allocate space for. Defaults to 0. ### `__copyinit__` `__copyinit__(existing: Self) -> Self` Initialize by copying an existing `IntArray`. For owned arrays, this performs a deep copy of the data. **Args:** * ​existing (`Self`): The source array to copy from. ### `__del__` `__del__(deinit self)` Destroy the `IntArray` and free its memory if owned. Only frees memory for owned arrays (positive \_size) to prevent double-free errors with views. ### `__getitem__` `__getitem__(self, idx: Int) -> Int` Access an element at the specified index. Note: Bounds checking is performed when assertions are enabled (e.g., -D ASSERT=all). **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Zero-based index of the element to access. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value at the specified index. 
### `__setitem__` `__setitem__(mut self, idx: Int, value: Int)` Set the value at the specified index. Note: Bounds checking is performed when assertions are enabled (e.g., -D ASSERT=all). **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Zero-based index of the element to modify. * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to store at the specified index. ### `owning` `owning(self) -> Bool` Check if this `IntArray` owns its memory. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this array owns its memory (positive \_size), False if it's a view (negative \_size). ### `size` `size(self) -> Int` Get the number of elements in the array. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the array, regardless of ownership status. ### `copy_from` `copy_from(mut self, offset: Int, source: Self, size: Int)` Copy elements from another `IntArray`. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Destination offset in this array. * ​source (`Self`): Source array to copy from. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements to copy. `copy_from(mut self, dst_offset: Int, source: Self, src_offset: Int, size: Int)` Copy elements from another IntArray with source offset. **Args:** * ​dst\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Destination offset in this array. * ​source (`Self`): Source array to copy from. * ​src\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Source offset in the source array. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements to copy.
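The `owning`, `size`, and `__del__` descriptions above all hinge on the sign of the internal `_size` field: positive means the array owns its storage, negative marks a non-owning view. A plain-Python model of that convention (illustrative only; the class name and `view` constructor are hypothetical, and the real struct manages raw memory rather than a Python list):

```python
class IntArrayModel:
    """Illustrative model of IntArray's sign-encoded ownership."""

    def __init__(self, size=0):
        self._data = [0] * size
        self._size = size  # non-negative: owning

    @classmethod
    def view(cls, other):
        v = cls.__new__(cls)
        v._data = other._data          # shared storage, no copy
        v._size = -len(other._data)    # negative: view
        return v

    def owning(self):
        return self._size >= 0

    def size(self):
        return abs(self._size)  # logical length, regardless of ownership

    def __getitem__(self, idx):
        assert 0 <= idx < self.size(), "index out of bounds"
        return self._data[idx]

    def __setitem__(self, idx, value):
        assert 0 <= idx < self.size(), "index out of bounds"
        self._data[idx] = value

a = IntArrayModel(3)
a[0], a[1], a[2] = 10, 20, 30
v = IntArrayModel.view(a)
assert a.owning() and not v.owning()
assert v.size() == 3 and v[1] == 20
```

The sign trick is why `__del__` can safely skip freeing for views: only a positive `_size` triggers deallocation, preventing double-frees when a view outlives or precedes its owner.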
--- ## IntTuple
`struct IntTuple` A hierarchical, nested tuple of integers with efficient memory management. IntTuple provides a flexible data structure for representing multi-dimensional shapes, indices, and other nested integer collections. It supports both flat and hierarchical representations with efficient memory sharing. This structure is fundamental for tensor operations, layout specifications, and dimension handling in high-performance computing contexts. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _IntTupleIter[origin_of((muttoimm iterable_origin._mlir_origin))]` The iterator type for IntTuple iteration. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ### `MinimumValue` `comptime MinimumValue = -65534` Minimum allowed value for integers in an `IntTuple`. This constant defines the lower bound for integer values that can be stored directly in an `IntTuple`. 
Values below this threshold are reserved for internal use to represent structural information like sub-tuple offsets. ## Methods ### `__init__` `__init__(out self)` Initialize an empty IntTuple. Creates an `IntTuple` with zero elements, which can be used as a starting point for building tuples incrementally with `append` or `extend`. Performance: * Minimal allocation (just a single element for length). * Structure validation performed when assertions are enabled. `__init__(out self, *, num_elems: Int)` Initialize an `IntTuple` with a specified number of uninitialized elements. Creates an `IntTuple` with space for the specified number of elements, but does not initialize the elements themselves. Note: Structure validation performed when assertions are enabled. **Args:** * ​num\_elems ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to allocate space for. `__init__(out self, *elements: Int)` Initialize an `IntTuple` with a variadic list of integers. Creates an `IntTuple` containing the provided integer values. **Args:** * ​\*elements ([`Int`](/mojo/stdlib/builtin/int/Int)): Variable number of integer values to store in the tuple. `__init__(out self, elements: VariadicList[Int])` Initialize an `IntTuple` with a list of integers. Creates an `IntTuple` containing the provided integer values. Notes: * Pre-allocates exact memory needed for efficiency. * Validates that all values are above `MinimumValue`. If any value is less than `MinimumValue`, assertion fails with an error message. * Structure validation performed when assertions are enabled. **Args:** * ​elements ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of integer values to store in the tuple. `@implicit` `__init__(out self, value: Int)` Initialize an `IntTuple` with a single integer value. Creates an `IntTuple` containing a single integer element. **Args:** * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to store in the tuple. 
`__init__(out self, *elements: Self, *, __list_literal__: Tuple[] = Tuple[]())` Initialize an `IntTuple` with nested IntTuples. Creates a hierarchical `IntTuple` containing the provided `IntTuple` elements, preserving their nested structure. **Args:** * ​\*elements (`Self`): Variable number of `IntTuple` values to store in the tuple. * ​\_\_list\_literal\_\_ ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Specifies that this constructor can be used for list literals. `__init__(out self, *, var _owned: IntArray)` Initialize an `IntTuple` taking the values of an `IntArray`. **Args:** * ​\_owned ([`IntArray`](/mojo/kernels/layout/int_tuple/IntArray)): The `IntArray` to use as storage. `__init__(out self, existing: Self, rng: _StridedRange)` Initialize an `IntTuple` as a slice of an existing `IntTuple`. Creates a new `IntTuple` containing only the elements from the existing `IntTuple` that are specified by the range. Notes: * Preserves nested structure of elements in the slice. * Structure validation performed when assertions are enabled. **Args:** * ​existing (`Self`): The source `IntTuple` to slice from. * ​rng ([`_StridedRange`](/mojo/stdlib/builtin/range/_StridedRange)): The range of indices to include in the new `IntTuple`. `__init__(out self, dimlist: DimList)` Initialize an `IntTuple` from a DimList. Creates an `IntTuple` containing the dimensions from a DimList, handling both defined and undefined dimensions appropriately. Notes: * Converts undefined dimensions to `UNKNOWN_VALUE`. * Validates that all values are above `MinimumValue`. If any value is less than `MinimumValue`, assertion fails with an error message. **Args:** * ​dimlist ([`DimList`](/mojo/kernels/buffer/dimlist/DimList)): The DimList containing dimension information. `__init__[IterableType: Iterable](out self, iterable: IterableType)` Initialize an `IntTuple` from an iterable. Creates an `IntTuple` by appending each element from the iterable (for example, a zip iterator). 
Note: This implementation is not optimized and may be improved in future versions. **Parameters:** * ​IterableType ([`Iterable`](/mojo/stdlib/iter/Iterable)): The type of the iterable. **Args:** * ​iterable (`IterableType`): An iterable containing pairs of elements to append. ### `__copyinit__` `__copyinit__(out self, existing: Self)` Initialize by copying an existing `IntTuple`. Creates a deep copy of the provided `IntTuple`, copying all its data into newly allocated memory. Note: There is a Mojo bug where this method unnecessarily propagates the origin of self to the new copy. **Args:** * ​existing (`Self`): The `IntTuple` to copy from. ### `__getitem__` `__getitem__(self, _idx: Int) -> Self` Retrieves an element at the specified index from the `IntTuple`. Supports negative indexing (e.g., `-1` for the last element). Notes: If index is out of bounds, assertion fails with an error message. **Args:** * ​\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to retrieve. **Returns:** `Self`: An `IntTuple` containing either a single value or a sub-tuple. `__getitem__(self, span: Slice) -> Self` Retrieves a slice of elements from the `IntTuple`. Creates a new `IntTuple` containing the elements specified by the slice. **Args:** * ​span ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): A slice object specifying the range of elements to retrieve. **Returns:** `Self`: A new `IntTuple` containing the specified elements. ### `__lt__` `__lt__(self, rhs: Self) -> Bool` Compare two `IntTuple`s lexicographically. This function performs element-wise comparison of two `IntTuple`s and determines if the first is lexicographically less than the second. It compares corresponding elements until it finds a pair where the elements differ. 
Example: ```mojo from layout.int_tuple import IntTuple var tuple1 = IntTuple(1, 2, 3) var tuple2 = IntTuple(1, 2, 4) var result = tuple1 < tuple2 # Returns True because 3 < 4 ``` **Args:** * ​rhs (`Self`): The other `IntTuple` to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is lexicographically less than `rhs`, False otherwise. ### `__eq__` `__eq__(self, other: Self) -> Bool` Equality operator for `IntTuple`. **Args:** * ​other (`Self`): The `IntTuple` to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `IntTuple`s are equal, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Inequality operator for `IntTuple`. **Args:** * ​other (`Self`): The `IntTuple` to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `IntTuple`s are not equal, False otherwise. ### `elements_size` `static elements_size(elements: VariadicListMem[IntTuple, origin, is_owned]) -> Int` Calculate the total storage size needed for a list of IntTuples. Computes the sum of sizes for all elements, accounting for both direct integer values and nested sub-tuples. **Args:** * ​elements ([`VariadicListMem`](/mojo/stdlib/builtin/variadics/VariadicListMem)): List of `IntTuple` elements to measure. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total storage size required for all elements. `static elements_size[_origin: ImmutOrigin, n: Int](elements: InlineArray[Pointer[IntTuple, _origin], n], idx: Int) -> Int` Calculate the total storage size needed for IntTuples at a specific index. Computes the sum of sizes for all elements at the given index in an array of `IntTuple` pointers. **Parameters:** * ​\_origin ([`ImmutOrigin`](/mojo/stdlib/builtin/type_aliases/#immutorigin)): Origin tracking for memory safety. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the inline array. 
**Args:** * ​elements ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): Array of pointers to `IntTuple`s. * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Index to access in each `IntTuple`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total storage size required for all elements at the specified index. ### `owned_copy` `owned_copy(self) -> Self` Create a deep copy of this `IntTuple` with its own memory ownership. This method creates a completely independent copy of the `IntTuple` with newly allocated memory. Unlike `__copyinit__`, this method can be called on an existing instance to create a separate copy. Example: ```mojo from layout import IntTuple var original = IntTuple(1, 2, 3) var copy = original.owned_copy() # Modifying copy will not affect original ``` **Returns:** `Self`: A new `IntTuple` containing the same data as this one but with independent memory ownership. ### `replace_entry` `replace_entry(self, idx: Int, value: Self) -> Self` Replace an entry in the tuple with another `IntTuple`. Creates a new `IntTuple` with the element at the specified index replaced by the provided `IntTuple`. Note: If the index is out of bounds, assertion fails with an error message. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to replace. * ​value (`Self`): The `IntTuple` to insert at the specified index. **Returns:** `Self`: A new `IntTuple` with the replacement applied. `replace_entry(mut self, idx: Int, *, int_value: Int)` Replace an integer value at the specified index in-place. Directly modifies the tuple by replacing the integer value at the given index. This is more efficient than creating a new tuple when only a single value needs to be changed. Note: If the index is out of bounds, assertion fails with an error message. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to replace. 
* ​int\_value ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to insert at the specified index. ### `count_values` `count_values(self) -> Int` Count the total number of integer values in this tuple hierarchy. Recursively traverses the nested tuple structure and counts all integer values. This is useful for determining the size needed for flattened representations. Note: For a flat tuple, this will return the same value as `len(self)`. For nested tuples, it counts all leaf integer values. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total count of integer values in this tuple and all nested tuples. ### `flatten` `flatten(self) -> Self` Flatten a nested `IntTuple` into a single-level `IntTuple`. This function converts a hierarchical `IntTuple` structure into a flat sequence of integer values, preserving the order of elements. **Returns:** `Self`: A new `IntTuple` containing all integer values in a flat structure. ### `product_flatten` `product_flatten(self) -> Self` Coalesces a nested `IntTuple` into a single-level `IntTuple`, by multiplying all the values together. **Returns:** `Self`: A new `IntTuple` containing the products of each top level tuple, in a flat structure. ### `all_known` `all_known(self) -> Bool` Check if all values in this tuple hierarchy are known (not `UNKNOWN_VALUE`). Recursively traverses the nested tuple structure and checks if any value is equal to `UNKNOWN_VALUE`. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all values in this tuple and nested tuples are known, False if any value is `UNKNOWN_VALUE`. `all_known[start: Int, end: Int](self) -> Bool` Check if all values in this tuple hierarchy are known (not `UNKNOWN_VALUE`). Recursively traverses the nested tuple structure and checks if any value is equal to `UNKNOWN_VALUE`. **Parameters:** * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The starting index (inclusive) for the range to check. 
* ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The ending index (exclusive) for the range to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all values in this tuple and nested tuples are known, False if any value is `UNKNOWN_VALUE`. ### `append` `append(mut self, *elements: Self)` Append one or more `IntTuple` elements to this tuple. This method modifies the tuple in-place by adding the provided elements to the end of the tuple. It handles both value tuples and nested tuples. Notes: * This operation requires reallocating the underlying `IntArray` storage to accommodate the new elements, which may impact performance for large tuples. **Args:** * ​\*elements (`Self`): Variable number of `IntTuple` objects to append to this tuple. ### `extend` `extend(mut self, tuple: Self)` Extends this tuple by appending all elements from another tuple. This method modifies the tuple in-place by adding all elements from the provided tuple to the end of this tuple. It efficiently handles both value elements and nested tuples. Notes: * This operation requires reallocating the underlying `IntArray` storage to accommodate the new elements, which may impact performance for large tuples. * If the input tuple is empty, this method returns without making any changes. **Args:** * ​tuple (`Self`): The `IntTuple` whose elements will be appended to this tuple. ### `size` `size(self) -> Int` Returns the total size of the `IntTuple` in memory. For owning tuples, returns the size of the underlying `IntArray`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total size in memory units. ### `tuple_size` `static tuple_size(data: IntArray) -> Int` Recursively calculates the size of a tuple represented by an `IntArray`. This method traverses the tuple structure, accounting for both direct values and nested sub-tuples to compute the total memory footprint. **Args:** * ​data ([`IntArray`](/mojo/kernels/layout/int_tuple/IntArray)): `IntArray` containing the tuple data. 
**Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total size of the tuple in memory units. ### `validate_structure` `validate_structure(self)` Validates the internal structure of the `IntTuple`. Ensures that the actual size of the underlying data matches the computed size based on the tuple's structure. This helps detect memory corruption or implementation errors. Assertion fails with an error message if validation fails. ### `__len__` `__len__(self) -> Int` Returns the number of elements in the `IntTuple`. This is the logical length of the tuple, not its memory size. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the tuple. ### `__iter__` `__iter__(ref self) -> _IntTupleIter[origin_of((muttoimm self_is_origin))]` Returns an iterator over the elements of the `IntTuple`. This enables iteration through the tuple using for-loops. **Returns:** `_IntTupleIter`: An iterator object for this `IntTuple`. ### `is_value` `is_value(self) -> Bool` Determines if this `IntTuple` represents a single value rather than a tuple. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this `IntTuple` contains exactly one element that is a value, False otherwise. `is_value(self, i: Int) -> Bool` Determines if the element at the specified index is a value rather than a tuple. Notes: If index is out of bounds, assertion fails with an error message. **Args:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the element at index i is a value, False if it's a tuple. ### `is_tuple` `is_tuple(self) -> Bool` Determines if this `IntTuple` represents a tuple rather than a single value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this `IntTuple` is a tuple (not a single value), False otherwise. `is_tuple(self, i: Int) -> Bool` Determines if the element at the specified index is a tuple rather than a value. 
Notes: This is the complement of is\_value(i). **Args:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the element at index i is a tuple, False if it's a value. ### `value` `value(self) -> Int` Retrieves the value of this `IntTuple` if it represents a single value. This method should only be called if `is_value()` returns True. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value stored in this `IntTuple`. `value(self, i: Int) -> Int` Retrieves the value of the element at the specified index. This method should only be called if `is_value(i)` returns True. Notes: If the element is not a value, the behavior is undefined. **Args:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to retrieve. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value stored at the specified index. ### `tuple` `tuple(ref self) -> ref [self] Self` Returns a reference to this `IntTuple` as a tuple. Notes: This method is used to access the current `IntTuple` as a tuple without creating a copy of the data. **Returns:** `ref`: A reference to this `IntTuple` to avoid unnecessary copying. ### `write_to` `write_to(self, mut writer: T)` Writes a string representation of this `IntTuple` to the provided writer. Notes: For single values, writes just the value. For tuples, writes a comma-separated list of elements enclosed in parentheses. **Args:** * ​writer (`T`): The writer to output the string representation to. ### `__str__` `__str__(self) -> String` Returns a string representation of this `IntTuple`. **Returns:** `String`: A string representation of the `IntTuple`, using the `write_to` method. ### `is_equal` `static is_equal(a, b: Self) -> Bool` Compares two `IntTuple`s for equality. Notes: Handles nested tuples and special cases where a single-element tuple is equivalent to its contained value. 
**Args:** * ​a (`Self`): The first `IntTuple` to compare. * ​b (`Self`): The second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `IntTuple`s are equal in structure and values, False otherwise. ### `__repr__` `__repr__(self) -> String` Returns a string representation of this `IntTuple` for debugging. **Returns:** `String`: A string representation of the `IntTuple`, same as `__str__`. ### `__int__` `__int__(self) -> Int` Converts this `IntTuple` to an integer. This method should only be called if `is_value()` returns True. Aborts: If the `IntTuple` is not a single value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value stored in this `IntTuple`.
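Conceptually, an `IntTuple` behaves like a nested Python tuple whose leaves are ints. The following plain-Python sketch (illustrative only, not the Mojo implementation) models the documented behavior of `flatten`, `count_values`, and `product_flatten`:

```python
from math import prod

def flatten(t):
    """Flatten a nested tuple to its leaf ints, preserving order
    (mirrors IntTuple.flatten)."""
    if isinstance(t, int):
        return (t,)
    out = ()
    for e in t:
        out += flatten(e)
    return out

def count_values(t):
    """Count all leaf ints (mirrors IntTuple.count_values)."""
    return len(flatten(t))

def product_flatten(t):
    """Collapse each top-level element to the product of its leaves
    (one reading of the documented product_flatten)."""
    return tuple(prod(flatten(e)) for e in t)

nested = (2, (3, 4), 5)
assert flatten(nested) == (2, 3, 4, 5)
assert count_values(nested) == 4      # len(nested) would be 3
assert product_flatten(nested) == (2, 12, 5)
```

This also illustrates the note under `count_values`: for a flat tuple it equals `len(self)`, while for nested tuples it counts every leaf.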
--- ## abs
`abs(t: IntTuple) -> IntTuple` Compute the absolute value of each element in an `IntTuple`. This function applies the absolute value operation to each integer in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to transform. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure but with absolute values.
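A structure-preserving elementwise operation like this can be sketched in Python over nested int tuples (an illustrative model, not the Mojo code):

```python
def tuple_abs(t):
    """Absolute value of every leaf int, keeping the nesting intact."""
    if isinstance(t, int):
        return abs(t)
    return tuple(tuple_abs(e) for e in t)

assert tuple_abs((-2, (3, -4), 5)) == (2, (3, 4), 5)
```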
--- ## apply
`apply[func: fn(Int) capturing -> Int](t: IntTuple) -> IntTuple` Apply a function to each integer value in an `IntTuple`. This function recursively applies the given function to each integer value in a potentially nested `IntTuple` structure, preserving the structure. **Parameters:** * ​func (`fn(Int) capturing -> Int`): Function to apply to each integer value. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to transform. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure but with each integer value transformed by the function.
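`apply` generalizes elementwise transforms such as `abs`: it maps an `Int -> Int` function over every leaf while preserving nesting. A Python model (illustrative only):

```python
def apply(func, t):
    """Recursively apply func to each leaf int, preserving structure."""
    if isinstance(t, int):
        return func(t)
    return tuple(apply(func, e) for e in t)

assert apply(lambda x: x * 2, (1, (2, 3))) == (2, (4, 6))
assert apply(abs, (-1, (2, -3))) == (1, (2, 3))  # abs as a special case
```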
--- ## apply_predicate
`apply_predicate[predicate: fn(IntTuple, IntTuple) -> Bool](a: IntTuple, b: IntTuple) -> Bool` Apply a predicate function recursively to two `IntTuple`s. This function traverses two `IntTuple`s with the same structure and applies a predicate function to corresponding elements. The predicate is applied only to the leaf nodes (integer values). Note: If the structures of the two `IntTuple`s don't match (different nesting or length), the function returns False without applying the predicate. **Parameters:** * ​predicate (`fn(IntTuple, IntTuple) -> Bool`): A function that takes two `IntTuple`s (containing integer values) and returns a boolean result. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the predicate returns True for all corresponding elements and the structures match, False otherwise.
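The structure-matching rule can be sketched as follows. This is a Python model; note one simplification: the real predicate receives `IntTuple` leaves, whereas this sketch passes raw ints.

```python
def apply_predicate(pred, a, b):
    """Apply pred to corresponding leaves; any structural mismatch
    (different nesting or length) yields False."""
    a_tup, b_tup = isinstance(a, tuple), isinstance(b, tuple)
    if a_tup != b_tup:
        return False                 # nesting mismatch
    if not a_tup:
        return pred(a, b)            # both leaves
    if len(a) != len(b):
        return False                 # length mismatch
    return all(apply_predicate(pred, x, y) for x, y in zip(a, b))

le = lambda x, y: x <= y
assert apply_predicate(le, (1, (2, 3)), (1, (2, 4)))
assert not apply_predicate(le, (1, 2), (1, (2, 3)))  # structures differ
```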
--- ## apply_zip
`apply_zip[func: fn(IntTuple, IntTuple) -> IntTuple](t1: IntTuple, t2: IntTuple) -> IntTuple` Apply a function to pairs of elements from two `IntTuple`s. This function zips two `IntTuple`s together and applies the given function to each pair of elements, creating a new `IntTuple` with the results. **Parameters:** * ​func (`fn(IntTuple, IntTuple) -> IntTuple`): Function that takes two `IntTuple`s and returns an `IntTuple`. **Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each pair. `apply_zip[func: fn(IntTuple, IntTuple) capturing -> IntTuple](t1: IntTuple, t2: IntTuple) -> IntTuple` Apply a capturing function to pairs of elements from two `IntTuple`s. This overload allows the function to capture variables from its environment. **Parameters:** * ​func (`fn(IntTuple, IntTuple) capturing -> IntTuple`): Capturing function that takes two `IntTuple`s and returns an `IntTuple`. **Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each pair. `apply_zip[func: fn(IntTuple, IntTuple, IntTuple) -> IntTuple](t1: IntTuple, t2: IntTuple, t3: IntTuple) -> IntTuple` Apply a function to triplets of elements from three `IntTuple`s. This function zips three `IntTuple`s together and applies the given function to each triplet of elements, creating a new `IntTuple` with the results. **Parameters:** * ​func (`fn(IntTuple, IntTuple, IntTuple) -> IntTuple`): Function that takes three `IntTuple`s and returns an `IntTuple`. 
**Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. * ​t3 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Third `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each triplet. `apply_zip[func: fn(IntTuple, IntTuple, IntTuple) capturing -> IntTuple](t1: IntTuple, t2: IntTuple, t3: IntTuple) -> IntTuple` Apply a capturing function to triplets of elements from three `IntTuple`s. This overload allows the function to capture variables from its environment. **Parameters:** * ​func (`fn(IntTuple, IntTuple, IntTuple) capturing -> IntTuple`): Capturing function that takes three `IntTuple`s and returns an `IntTuple`. **Args:** * ​t1 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​t2 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. * ​t3 ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Third `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the results of applying func to each triplet.
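A Python model of the two-tuple case (illustrative only; the real `func` maps `IntTuple` pairs to an `IntTuple` and may itself recurse into sub-tuples, as the example element function does here):

```python
def apply_zip(func, t1, t2):
    """Pair up the top-level elements of t1 and t2 and apply func
    to each pair, collecting the results."""
    return tuple(func(a, b) for a, b in zip(t1, t2))

def add_leaves(a, b):
    """Example element function: recursive elementwise addition."""
    if isinstance(a, int):
        return a + b
    return tuple(add_leaves(x, y) for x, y in zip(a, b))

assert apply_zip(add_leaves, (1, (2, 3)), (10, (20, 30))) == (11, (22, 33))
```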
--- ## compact_order
`compact_order(shape: IntTuple, order: IntTuple) -> IntTuple` Create a compact stride based on shape and order. This function generates a stride tuple where lower order numbers imply faster varying strides. The resulting shape and stride form a bijective layout. Performance: * Always inlined for optimal performance in tight loops. * Flattens inputs and re-nests results for consistent behavior. Example: ```mojo from layout import IntTuple from layout.int_tuple import compact_order # Create a compact layout with dimensions (2,3,4,5) and ordering (1,4,3,5) var x = compact_order(IntTuple(2,3,4,5), IntTuple(1,4,3,5)) # returns (1,8,2,24) # Create a compact layout with nested dimensions and corresponding ordering var y = compact_order(IntTuple(2,IntTuple(3,4),5), IntTuple(1,IntTuple(2,3),4)) # returns (1,(2,6),24) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape tuple defining dimensions. * ​order ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The order tuple defining the relative ordering of dimensions. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A stride tuple that creates a compact memory layout according to the specified order.
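The stride computation can be modeled in Python: flatten shape and order, walk the dimensions from lowest to highest order number assigning a running-product stride, then re-nest to the shape's structure. The sketch is illustrative, but its assertions reproduce the documented examples above:

```python
def flatten(t):
    if isinstance(t, int):
        return [t]
    out = []
    for e in t:
        out += flatten(e)
    return out

def unflatten(flat, like):
    """Re-nest a flat list to match the structure of `like`."""
    it = iter(flat)
    def build(t):
        return next(it) if isinstance(t, int) else tuple(build(e) for e in t)
    return build(like)

def compact_order(shape, order):
    """Strides where lower order numbers imply faster-varying dims."""
    sizes, orders = flatten(shape), flatten(order)
    strides, running = [0] * len(sizes), 1
    for i in sorted(range(len(sizes)), key=lambda i: orders[i]):
        strides[i] = running
        running *= sizes[i]
    return unflatten(strides, shape)

# Both expected results come from the documentation's own examples.
assert compact_order((2, 3, 4, 5), (1, 4, 3, 5)) == (1, 8, 2, 24)
assert compact_order((2, (3, 4), 5), (1, (2, 3), 4)) == (1, (2, 6), 24)
```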
--- ## compatible
`compatible(a: IntTuple, b: IntTuple) -> Bool` Test if two shapes are compatible for tensor operations. This function checks if shape A is compatible with shape B, meaning: 1. The total size of A and B are the same 2. Any coordinate into A can also be used as a coordinate into B Compatible can also be thought of as a partial order on A and B: A <= B. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if shape A is compatible with shape B, False otherwise.
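The recursion behind this check can be sketched in Python (a model of the semantics, not the Mojo implementation), treating an `IntTuple` as a nested Python tuple of ints: a scalar dimension is compatible with a nested shape of the same total size, but a nested shape is never compatible with a scalar.

```python
def is_int(t):
    return isinstance(t, int)

def size(t):
    # Product of all leaf values.
    if is_int(t):
        return t
    r = 1
    for e in t:
        r *= size(e)
    return r

def compatible(a, b):
    # Scalar vs. scalar: dimensions must match exactly.
    if is_int(a) and is_int(b):
        return a == b
    # Scalar vs. nested: indices 0..a-1 can enumerate b if sizes agree.
    if is_int(a):
        return a == size(b)
    # Nested vs. scalar: a tuple coordinate cannot index a single dimension.
    if is_int(b):
        return False
    return len(a) == len(b) and all(compatible(x, y) for x, y in zip(a, b))
```

For example, `compatible(12, (3, 4))` holds because indices `0..11` enumerate a 3×4 shape, while `compatible((3, 4), 12)` fails because a pair coordinate cannot index a single dimension.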
--- ## congruent
`congruent(a: IntTuple, b: IntTuple) -> Bool` Test if two `IntTuple`s have the same hierarchical structure. This function checks if two `IntTuple`s have identical nesting patterns, regardless of the actual integer values they contain. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both `IntTuple`s have the same hierarchical structure, False otherwise.
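The structural check can be sketched in Python (a model, not the Mojo implementation), with an `IntTuple` modeled as a nested tuple of ints:

```python
def congruent(a, b):
    a_is_int = isinstance(a, int)
    b_is_int = isinstance(b, int)
    if a_is_int and b_is_int:
        return True          # values need not match, only structure
    if a_is_int or b_is_int:
        return False         # scalar vs. tuple: different structure
    return len(a) == len(b) and all(congruent(x, y) for x, y in zip(a, b))
```

So `congruent((1, (2, 3)), (9, (8, 7)))` is true even though every value differs, while `congruent((1, 2), (1, (2, 3)))` is false.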
--- ## crd2idx
`crd2idx(crd: IntTuple, shape: IntTuple) -> Int` Map a logical coordinate to a linear index. This function converts a multi-dimensional coordinate to a linear index based on the shape. It uses default strides computed from the shape. **Args:** * ​crd ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The coordinate tuple to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The linear index corresponding to the coordinate. `crd2idx(crd: IntTuple, shape: IntTuple, _stride: IntTuple) -> Int` Map a logical coordinate to a linear index with custom strides. This function converts a multi-dimensional coordinate to a linear index based on the shape and stride information. If no stride is provided, it computes default strides from the shape. The function handles various input combinations: * Tuple coordinates with tuple shapes and strides * Single integer coordinate with tuple shapes and strides * Single integer coordinate with single integer shape and stride Aborts: * If coordinate and shape dimensions don't match. * If shape and stride dimensions don't match. * If input type combinations are invalid. **Args:** * ​crd ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The coordinate(s) to convert, can be a single value or a tuple of coordinates. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array, can be a single value or a tuple of dimensions. * ​\_stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Optional custom strides; defaults to colexicographical strides computed from the shape via `prefix_product` if not provided. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The linear index corresponding to the coordinate.
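For the congruent cases, the mapping is a recursive dot product of coordinate and stride. A Python sketch (a model only; it omits the overload that splits a single integer coordinate across a nested shape):

```python
def crd2idx(crd, shape, stride):
    # Scalar coordinate against a scalar dimension: contribution is crd * stride.
    if isinstance(crd, int) and isinstance(shape, int):
        return crd * stride
    # Congruent tuples: sum the contributions of the sub-coordinates.
    assert len(crd) == len(shape) == len(stride), "dimension mismatch"
    return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride))
```

For example, `crd2idx((1, 2), (3, 4), (1, 3))` gives `1*1 + 2*3 = 7`.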
--- ## depth
`depth(src: IntTuple) -> Int` Calculates the maximum nesting depth of an `IntTuple`. This function recursively traverses the `IntTuple` structure to determine its maximum nesting depth. A scalar value has depth 0, a flat tuple has depth 1, and nested tuples increase the depth accordingly. Example: ```mojo from layout import IntTuple, depth print(depth(IntTuple(1))) # prints 0 print(depth(IntTuple(1, 2))) # prints 1 print(depth(IntTuple(IntTuple(1, 2)))) # prints 2 ``` **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to measure the depth of. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer representing the maximum nesting depth.
--- ## fill_like
`fill_like(src: IntTuple, val: Int) -> IntTuple` Creates an `IntTuple` with the same structure as the source but filled with a specified value. This function recursively traverses the source `IntTuple` and creates a new `IntTuple` with identical structure, but with all leaf values replaced by the specified value. **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The source `IntTuple` whose structure will be copied. * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to fill the new `IntTuple` with. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure as src but filled with val.
--- ## flatten
`flatten(t: IntTuple) -> IntTuple` Flatten a nested `IntTuple` into a single-level `IntTuple`. This function converts a hierarchical `IntTuple` structure into a flat sequence of integer values, preserving the order of elements. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The nested `IntTuple` to flatten. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing all integer values in a flat structure.
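The traversal can be sketched in Python (a model, not the Mojo implementation), with nested Python tuples standing in for `IntTuple`:

```python
def flatten(t):
    # A scalar becomes a one-element flat tuple; nested tuples are
    # concatenated left to right, preserving element order.
    if isinstance(t, int):
        return (t,)
    out = ()
    for e in t:
        out += flatten(e)
    return out
```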
--- ## idx2crd
`idx2crd(idx: IntTuple, shape: IntTuple) -> IntTuple` Converts a linear index to a coordinate tuple within a given shape. This function splits an index into a coordinate within a Shape via a colexicographical enumeration of coordinates in Shape. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the coordinates corresponding to the linear index. `idx2crd(idx: IntTuple, shape: IntTuple, _stride: IntTuple) -> IntTuple` Converts a linear index to a coordinate tuple within a given shape using custom strides. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. * ​\_stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Custom strides to use for the conversion. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the coordinates corresponding to the linear index.
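For a flat shape with explicit strides, each coordinate is recovered by dividing out its stride and wrapping by its extent. A Python sketch of this flat case (a model only; the Mojo version also handles nested shapes):

```python
def idx2crd(idx, shape, stride):
    # Colexicographic split: coordinate i is (idx // stride_i) % shape_i.
    return tuple((idx // d) % s for s, d in zip(shape, stride))
```

This inverts `crd2idx`: with shape `(3, 4)` and strides `(1, 3)`, index 7 maps back to coordinate `(1, 2)`.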
--- ## idx2crd2
`idx2crd2(idx: IntTuple, shape: IntTuple, _stride: IntTuple) -> IntTuple` Convert a linear index to coordinates. This function handles the actual conversion logic for different input combinations. Notes: * Handles four cases: tuple-tuple-tuple, tuple-int-int, int-tuple-tuple, and int-int-int. * When input shapes don't match, `abort()` will be called. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the tensor/array. * ​\_stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Custom strides to use for the conversion. If empty, strides are computed from the shape using prefix\_product. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new IntTuple containing the coordinates corresponding to the linear index.
--- ## int_tuple
Hierarchical integer tuple data structures for high-performance tensor operations. This module provides a flexible, memory-efficient implementation of nested integer tuples optimized for tensor shape, stride, and index operations in high-performance computing. The core data structures support both flat and hierarchical representations with efficient memory sharing and zero-copy views. Key components: * `IntArray`: Low-level register-passable array with direct memory management * `IntTuple`: Hierarchical nested tuple with efficient memory layout and operations * Utility functions for tensor shape manipulation, coordinate transformations, and layout operations Performance features: * Register-passable data structures for optimal compiler optimizations * Zero-copy views for efficient memory sharing * Specialized memory layout for nested structures * Optimized algorithms for common tensor operations Common operations: * Shape manipulation: `flatten`, `to_nest`, `apply`, `product`, `sum` * Coordinate transformations: `idx2crd`, `crd2idx` * Layout operations: `compact_order`, `prefix_product` * Structural comparisons: `congruent`, `compatible`, `weakly_congruent` Example usage: ```mojo from layout import IntTuple from layout.int_tuple import flatten, compact_order, size # Create nested tuples var shape = IntTuple(2, IntTuple(3, 4), 5) # Represents shape (2, (3, 4), 5) # Flatten a nested tuple var flat = flatten(shape) # Results in (2, 3, 4, 5) # Create compact strides for a given shape and order var order = IntTuple(1, IntTuple(2, 3), 4) var strides = compact_order(shape, order) # Results in (1, (2, 6), 24) # Calculate total size (product of all elements) var total_size = size(shape) # Results in 120 ``` ## `comptime` values ### `IntList` `comptime IntList = List[Int]` A type alias for a List of integers with ownership. This alias defines a List that contains Int values and has ownership of its data. 
It's used throughout the module for storing and manipulating collections of integers, particularly for operations like permutations and indices. ### `UNKNOWN_VALUE` `comptime UNKNOWN_VALUE = -1` Special value indicating an unknown or unspecified dimension. This constant is used throughout the `IntTuple` system to represent dimensions that are not known at compile time or have not been specified. ## Structs * [​`IntArray`](./IntArray): A memory-efficient, register-passable array of integers. * [​`IntTuple`](./IntTuple): A hierarchical, nested tuple of integers with efficient memory management. ## Functions * [​`abs`](./abs): Compute the absolute value of each element in an `IntTuple`. * [​`apply`](./apply): Apply a function to each integer value in an `IntTuple`. * [​`apply_predicate`](./apply_predicate): Apply a predicate function recursively to two `IntTuple`s. * [​`apply_zip`](./apply_zip): Apply a function to pairs of elements from two `IntTuple`s. * [​`compact_order`](./compact_order): Create a compact stride based on shape and order. * [​`compatible`](./compatible): Test if two shapes are compatible for tensor operations. * [​`congruent`](./congruent): Test if two `IntTuple`s have the same hierarchical structure. * [​`crd2idx`](./crd2idx): Map a logical coordinate to a linear index. * [​`depth`](./depth): Calculates the maximum nesting depth of an `IntTuple`. * [​`fill_like`](./fill_like): Creates an `IntTuple` with the same structure as the source but filled with a specified value. * [​`flatten`](./flatten): Flatten a nested `IntTuple` into a single-level `IntTuple`. * [​`idx2crd`](./idx2crd): Converts a linear index to a coordinate tuple within a given shape. * [​`idx2crd2`](./idx2crd2): Convert a linear index to coordinates. * [​`inner_product`](./inner_product): Compute the inner product of two `IntTuple`s. * [​`is_flat`](./is_flat): Check if an `IntTuple` is flat. * [​`is_int`](./is_int): Check if an `IntTuple` represents a single integer value. 
* [​`is_tuple`](./is_tuple): Check if an `IntTuple` represents a nested tuple. * [​`mul`](./mul): Multiply each element in an `IntTuple` by a scalar value. * [​`prefix_product`](./prefix_product): Compute the exclusive prefix product of an `IntTuple`. * [​`product`](./product): Calculate the product of all values in an `IntTuple`. * [​`product_each`](./product_each): Compute the product of elements in each sub-tuple of an `IntTuple`. * [​`propagate_unknown`](./propagate_unknown): Propagates unknown dimensions from the target `IntTuple` to the source `IntTuple`. * [​`reduce`](./reduce): Apply a reduction function to an `IntTuple` with an initial value. * [​`reverse`](./reverse): Reverses the order of elements in an `IntTuple`, recursively. * [​`shallow_apply`](./shallow_apply): Apply a function to each top-level element of an `IntTuple`. * [​`shape_div`](./shape_div): Performs division operation between shape tuples. * [​`signum`](./signum): Calculate the sign of an integer. * [​`size`](./size): Calculate the total size (product of all elements) of an `IntTuple`. * [​`sorted`](./sorted): Sort an IntTuple using the provided comparison function. * [​`sum`](./sum): Calculate the sum of all values in an `IntTuple`. * [​`to_index_list`](./to_index_list): Converts an IntTuple to a flattened IndexList with the same values. * [​`to_nest`](./to_nest): Nests a flat `IntTuple` according to the structure of a nested `IntTuple`. * [​`to_unknown`](./to_unknown): Create an `IntTuple` with the same structure but filled with `UNKNOWN_VALUE`. * [​`tuple_max`](./tuple_max): Calculate the maximum value in an `IntTuple`. * [​`tuple_min`](./tuple_min): Compute the element-wise minimum of two `IntTuple`s. * [​`weakly_compatible`](./weakly_compatible): Test if shape A is weakly compatible with shape B. * [​`weakly_congruent`](./weakly_congruent): Test if two IntTuples have similar hierarchical structures.
--- ## inner_product
`inner_product(a: IntTuple, b: IntTuple) -> Int` Compute the inner product of two `IntTuple`s. For flat tuples, this is the sum of element-wise products. For nested tuples, the function recurses into corresponding nested elements. Note: If the input tuples have different lengths, assertion fails. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The inner product as an `Int`.
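The recursion can be sketched in Python (a model, not the Mojo implementation), with nested Python tuples standing in for `IntTuple`:

```python
def inner_product(a, b):
    # Leaf case: multiply the two scalars.
    if isinstance(a, int) and isinstance(b, int):
        return a * b
    # Nested case: recurse into corresponding elements and sum.
    assert len(a) == len(b), "tuples must have equal length"
    return sum(inner_product(x, y) for x, y in zip(a, b))
```

For example, `inner_product((1, 2, 3), (4, 5, 6))` is `4 + 10 + 18 = 32`, and the nested form `inner_product((1, (2, 3)), (4, (5, 6)))` yields the same value.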
--- ## is_flat
`is_flat(t: IntTuple) -> Bool` Check if an `IntTuple` is flat. This function checks if the `IntTuple` is flat, meaning it has no nested elements. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `IntTuple` is flat, False otherwise.
--- ## is_int
`is_int(t: IntTuple) -> Bool` Check if an `IntTuple` represents a single integer value. This function determines whether the given `IntTuple` contains a single integer value rather than a nested tuple structure. Example: ```mojo from layout.int_tuple import is_int, IntTuple var single_value = IntTuple(5) var nested_tuple = IntTuple(1, 2, 3) var result1 = is_int(single_value) # Returns True var result2 = is_int(nested_tuple) # Returns False ``` **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `IntTuple` contains a single integer value, False if it's a nested tuple.
--- ## is_tuple
`is_tuple(t: IntTuple) -> Bool` Check if an `IntTuple` represents a nested tuple. This function determines whether the given `IntTuple` contains nested elements rather than a single integer value. It is the complement of the `is_int` function. Example: ```mojo from layout.int_tuple import is_tuple, IntTuple var single_value = IntTuple(5) var nested_tuple = IntTuple(1, 2, 3) var result1 = is_tuple(single_value) # Returns False var result2 = is_tuple(nested_tuple) # Returns True ``` **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `IntTuple` contains nested elements, False if it's a single integer value.
--- ## mul
`mul(lhs: IntTuple, rhs: Int) -> IntTuple` Multiply each element in an `IntTuple` by a scalar value. This function creates a new `IntTuple` where each element (at any nesting level) is multiplied by the provided integer value. **Args:** * ​lhs ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` whose elements will be multiplied. * ​rhs ([`Int`](/mojo/stdlib/builtin/int/Int)): The scalar integer to multiply each element by. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure as the input but with all elements multiplied by the scalar value.
--- ## prefix_product
`prefix_product(a: IntTuple) -> IntTuple` Compute the exclusive prefix product of an `IntTuple`. This is a convenience wrapper that initializes the prefix product with 1. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The input `IntTuple` to compute the prefix product for. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the exclusive prefix product of the input. `prefix_product(a: IntTuple, init: Int) -> IntTuple` Compute the exclusive prefix product of an `IntTuple` with an initial value. This function delegates to the implementation in prefix\_product2. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The input `IntTuple` to compute the prefix product for. * ​init ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial value(s) for the prefix product, defaults to 1. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the exclusive prefix product of the input.
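For a flat tuple, the exclusive prefix product is a running product seeded with the initial value, where each output element is the product of all elements strictly to its left. A Python sketch of this flat case (a model only; the Mojo implementation in `prefix_product2` also handles nesting):

```python
def prefix_product(a, init=1):
    # Exclusive scan: emit the accumulator before folding in each element.
    out = []
    acc = init
    for v in a:
        out.append(acc)
        acc *= v
    return tuple(out)
```

Applied to a shape, this yields the default colexicographic strides: `prefix_product((3, 4, 5))` is `(1, 3, 12)`.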
--- ## product
`product(t: IntTuple) -> Int` Calculate the product of all values in an `IntTuple`. This function recursively computes the product of all integer values in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to multiply. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The product of all integer values, or `UNKNOWN_VALUE` if any value in the tuple is `UNKNOWN_VALUE`.
--- ## product_each
`product_each(t: IntTuple) -> IntTuple` Compute the product of elements in each sub-tuple of an `IntTuple`. For each immediate child of the input tuple, this function computes the product of all elements within that child. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` containing sub-tuples. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` where each element is the product of the corresponding sub-tuple in the input.
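The behavior can be sketched in Python (a model, not the Mojo implementation), with nested Python tuples standing in for `IntTuple`:

```python
def product(t):
    # Product of all leaf values, at any nesting depth.
    if isinstance(t, int):
        return t
    r = 1
    for e in t:
        r *= product(e)
    return r

def product_each(t):
    # One product per immediate child, preserving the top-level arity.
    return tuple(product(e) for e in t)
```

So `product_each((2, (3, 4), 5))` collapses the nested child to its product, giving `(2, 12, 5)`.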
--- ## propagate_unknown
`propagate_unknown(src: IntTuple, target: IntTuple) -> IntTuple` Propagates unknown dimensions from the target `IntTuple` to the source `IntTuple`. This function creates a new `IntTuple` by combining the source and target `IntTuple`s, preserving unknown dimensions (UNKNOWN\_VALUE) from the target while using values from the source for known dimensions. **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The source `IntTuple` containing known dimension values. * ​target ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The target `IntTuple` that may contain unknown dimensions (UNKNOWN\_VALUE). **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with unknown dimensions from target and known dimensions from src.
--- ## reduce
`reduce[reducer: fn(a: Int, b: IntTuple) capturing -> Int](t: IntTuple, initializer: Int) -> Int` Apply a reduction function to an `IntTuple` with an initial value. This function iterates through each element of the `IntTuple` and applies the provided reduction function cumulatively, starting with the initializer. **Parameters:** * ​reducer (`fn(a: Int, b: IntTuple) capturing -> Int`): A function that combines the accumulated result with the next element. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to reduce. * ​initializer ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial value for the reduction operation. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The final accumulated result after applying the reduction function to all elements in the `IntTuple`.
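The fold can be sketched in Python (a model, not the Mojo implementation; `add_all` is a hypothetical reducer, not part of the API). Note that the reducer receives each top-level element, which may itself be nested:

```python
def reduce(reducer, t, initializer):
    # Left fold over the top-level elements of t.
    acc = initializer
    for e in t:
        acc = reducer(acc, e)
    return acc

def add_all(acc, e):
    # Example reducer: accumulate the sum of every leaf value,
    # recursing when it is handed a nested element.
    if isinstance(e, int):
        return acc + e
    return reduce(add_all, e, acc)
```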
--- ## reverse
`reverse(src: IntTuple) -> IntTuple` Reverses the order of elements in an `IntTuple`, recursively. This function reverses the top-level elements of the `IntTuple` and recursively reverses any nested `IntTuple`s. Example: ```mojo from layout.int_tuple import IntTuple, reverse var t = IntTuple(1, 2, IntTuple(3, 4)) var reversed = reverse(t) # returns ((4, 3), 2, 1) ``` **Args:** * ​src ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The source `IntTuple` to reverse. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with elements in reversed order.
--- ## shallow_apply
`shallow_apply[func: fn(IntTuple) -> Int](t: IntTuple) -> IntTuple` Apply a function to each top-level element of an `IntTuple`. Unlike `apply()`, this function only operates on the immediate children of the input tuple without recursing into nested tuples. **Parameters:** * ​func (`fn(IntTuple) -> Int`): Function that takes an `IntTuple` and returns an `Int`. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` whose elements will be transformed. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the function applied to each top-level element.
--- ## shape_div
`shape_div[check: Bool = False](a: IntTuple, b: IntTuple) -> IntTuple` Performs division operation between shape tuples. Handles four cases: 1. tuple-tuple: Performs shape\_div element-wise when dimensions match 2. tuple-int: Folds the division of b across each element of a Example: `shape_div((4,5,6),40)` -> `shape_div((1,5,6),10)` -> `shape_div((1,1,6),2)` -> `(1,1,3)` 3. int-tuple: Returns `shape_div(a, product(b))` 4. int-int: Enforces the divisibility condition `a % b == 0 || b % a == 0` when possible Returns `a / b` with rounding away from `0` (that is, `1` or `-1` when `a < b`) Notes: * When tuple sizes don't match in the tuple-tuple case, `abort()` will be called. * When values are incompatible (neither divides the other) in the int-int case, `abort()` will be called. **Parameters:** * ​check ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to check for incompatible shapes. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The dividend `IntTuple`. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The divisor `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the result of the division operation
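The four cases can be traced with a Python model (positive dimensions assumed here; the Mojo version also handles signs and the `check` parameter):

```python
def product(t):
    if isinstance(t, int):
        return t
    r = 1
    for e in t:
        r *= product(e)
    return r

def shape_div(a, b):
    if isinstance(a, tuple):
        if isinstance(b, tuple):               # tuple-tuple: element-wise
            assert len(a) == len(b), "tuple sizes must match"
            return tuple(shape_div(x, y) for x, y in zip(a, b))
        out = []                               # tuple-int: fold b across a
        for e in a:
            out.append(shape_div(e, b))
            b = shape_div(b, e)                # remaining divisor
        return tuple(out)
    if isinstance(b, tuple):                   # int-tuple
        return shape_div(a, product(b))
    # int-int: one side must divide the other.
    assert a % b == 0 or b % a == 0, "incompatible shapes"
    return a // b if a % b == 0 else 1         # |a/b| < 1 rounds away from 0
```

Tracing the documented example: `shape_div((4, 5, 6), 40)` consumes the divisor step by step (40 → 10 → 2 → 1) and returns `(1, 1, 3)`.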
--- ## signum
`signum(a: Int) -> Int` Calculate the sign of an integer. This function determines the sign of the input integer and returns a corresponding indicator value. Example: ```mojo from layout.int_tuple import signum var result1 = signum(5) # Returns 1 var result2 = signum(-10) # Returns -1 var result3 = signum(0) # Returns 0 ``` **Args:** * ​a ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to determine the sign of. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): 1 if `a` > 0, -1 if `a` < 0, 0 if `a` == 0.
--- ## size
`size(a: IntTuple) -> Int` Calculate the total size (product of all elements) of an `IntTuple`. This function computes the product of all integer values in the `IntTuple`, regardless of nesting level. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` whose elements will be multiplied together. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The product of all elements in the `IntTuple`.
--- ## sorted
`sorted[cmp: fn(IntTuple, IntTuple) -> Bool = __lt__](tuple: IntTuple) -> IntTuple` Sort an `IntTuple` using the provided comparison function. This function implements a merge sort algorithm to efficiently sort the elements of an `IntTuple`. The sorting is stable and has `O(n log n)` time complexity. **Parameters:** * ​cmp (`fn(IntTuple, IntTuple) -> Bool`): A comparison function that takes two `IntTuple` elements and returns True if the first should come before the second. Defaults to `__lt__`, which performs lexicographical ordering. **Args:** * ​tuple ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to be sorted. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing the same elements as the input but sorted according to the comparison function.
--- ## sum
`sum(t: IntTuple) -> Int` Calculate the sum of all values in an `IntTuple`. This function recursively computes the sum of all integer values in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to sum. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The sum of all integer values, or `UNKNOWN_VALUE` if any value in the tuple is `UNKNOWN_VALUE`.
--- ## to_index_list
`to_index_list[rank: Int, element_type: DType = DType.int64](t: IntTuple) -> IndexList[rank, element_type=element_type]` Converts an IntTuple to a flattened IndexList with the same values. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the resulting IndexList. * ​element\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element type, must be integer type. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` defining the values. **Returns:** `IndexList`: An IndexList filled with the values of t.
--- ## to_nest
`to_nest(nested: IntTuple, flat: IntTuple) -> IntTuple` Nests a flat `IntTuple` according to the structure of a nested `IntTuple`. This function reshapes a flat sequence of values into a hierarchical structure that matches the pattern of a template nested `IntTuple`. Example: ```mojo from layout import IntTuple from layout.int_tuple import to_nest var result = to_nest(IntTuple(2, IntTuple(3, 4), 5), IntTuple(1, 2, 3, 4)) # returns IntTuple(1, (2, 3), 4) ``` **Args:** * ​nested ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The template `IntTuple` defining the desired structure. * ​flat ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The flat `IntTuple` containing the values to be nested. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the values from flat arranged in the structure of nested.
--- ## to_unknown
`to_unknown(t: IntTuple) -> IntTuple` Create an `IntTuple` with the same structure but filled with `UNKNOWN_VALUE`. This function preserves the hierarchical structure of the input `IntTuple` but replaces all integer values with `UNKNOWN_VALUE`. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The template `IntTuple` defining the structure. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with the same structure as t but with all values replaced by `UNKNOWN_VALUE`.
--- ## tuple_max
`tuple_max(t: IntTuple) -> Int` Calculate the maximum value in an `IntTuple`. This function recursively finds the maximum integer value in a potentially nested `IntTuple` structure. **Args:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` to search. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The maximum integer value found in the tuple.
--- ## tuple_min
`tuple_min(a: IntTuple, b: IntTuple) -> IntTuple` Compute the element-wise minimum of two `IntTuple`s. This function compares corresponding elements of two `IntTuple`s and returns a new `IntTuple` containing the minimum value at each position. Aborts: If the input tuples have different lengths. Note: If either input contains `UNKNOWN_VALUE`, the result will be `UNKNOWN_VALUE`. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First `IntTuple`. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second `IntTuple`. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` with each element being the minimum of the corresponding elements in a and b.
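The element-wise recursion, including `UNKNOWN_VALUE` propagation, can be sketched in Python (a model, not the Mojo implementation):

```python
UNKNOWN_VALUE = -1  # sentinel for an unknown dimension

def tuple_min(a, b):
    if isinstance(a, int) and isinstance(b, int):
        # An unknown on either side makes the result unknown.
        if a == UNKNOWN_VALUE or b == UNKNOWN_VALUE:
            return UNKNOWN_VALUE
        return min(a, b)
    assert len(a) == len(b), "tuples must have equal length"
    return tuple(tuple_min(x, y) for x, y in zip(a, b))
```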
--- ## weakly_compatible
`weakly_compatible(a: IntTuple, b: IntTuple) -> Bool` Test if shape A is weakly compatible with shape B. A shape A is weakly compatible with shape B if there exists a shape C congruent to A such that compatible(elem\_scale(A,C), B). This establishes a partial order relation between shapes where A <= B. Specifically, this checks if the size of B is divisible by the size of A, which is a necessary condition for weak compatibility. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first `IntTuple` to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second `IntTuple` to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if shape A is weakly compatible with shape B, False otherwise.
--- ## weakly_congruent
`weakly_congruent(a: IntTuple, b: IntTuple) -> Bool` Test if two IntTuples have similar hierarchical structures. This function establishes a partial order relation between IntTuples based on their hierarchical structure. It's less strict than congruent. **Args:** * ​a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): First IntTuple to compare. * ​b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Second IntTuple to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if a's structure is compatible with b's structure, False otherwise.
--- ## Layout (Layout)
`struct Layout` Represents a memory layout for multi-dimensional data. The Layout struct is the primary implementation of the LayoutTrait, providing a concrete representation of memory layouts using shape and stride information. It maps between logical coordinates and linear memory indices, enabling efficient access to multi-dimensional data. A Layout consists of: * shape: Defines the dimensions of the logical coordinate space * stride: Defines the step sizes in memory for each dimension The Layout struct supports various operations including: * Creation of row-major and column-major layouts * Conversion between coordinates and indices * Composition with other layouts * Iteration over sub-layouts Layouts can be hierarchical, with nested shapes and strides, allowing for complex memory access patterns like blocked or tiled layouts. ## Fields * ​shape (`IntTuple`): The dimensions of the layout. This field defines the size of each dimension in the logical coordinate space. For example, a shape of (3, 4) represents a 3x4 grid of elements. * ​stride (`IntTuple`): The memory step sizes for each dimension. This field defines how many elements to skip in memory when moving one unit in each dimension. For example, in a row-major 3x4 layout, the strides might be (4, 1), meaning moving one unit in the first dimension requires skipping 4 elements in memory, while moving one unit in the second dimension requires skipping 1 element. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `has_shape` `comptime has_shape = True` Indicates whether the layout has a valid shape. ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _LayoutIter[origin_of((muttoimm iterable_origin._mlir_origin))]` The iterator type for Layout iteration. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self)` Initializes an empty layout with no dimensions. Creates a layout with empty shape and stride tuples, which can be populated later using append operations. `__init__(out self, shape: IntTuple)` Initializes a layout with the given shape and column-major strides. Creates a layout with the specified shape and automatically calculates column-major strides (where the first dimension varies fastest in memory). **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The dimensions of the layout. 
`__init__(out self, shape: IntTuple, stride: IntTuple)` Initializes a layout with the given shape and stride. Creates a layout with explicitly specified shape and stride values. If an empty stride is provided, column-major strides are calculated. **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The dimensions of the layout. * ​stride ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The memory step size for each dimension, or empty for column-major. ### `__getitem__` `__getitem__(self, index: Int) -> Self` Returns a sub-layout for the specified dimension. **Args:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension index to extract. **Returns:** `Self`: A Layout containing the shape and stride for the specified dimension. ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if this layout is equal to another layout. Two layouts are considered equal if they have identical shape and stride tuples. **Args:** * ​other (`Self`): The layout to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the layouts are equal, False otherwise. ### `idx2crd` `idx2crd(self, idx: IntTuple) -> IntTuple` Converts a linear index to logical coordinates. This is the inverse operation of the `__call__` method, mapping from a memory index back to the corresponding logical coordinates. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The linear index to convert. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): The logical coordinates corresponding to the given index. ### `col_major` `static col_major(*dims: Int) -> Self` Creates a column-major layout with the specified dimensions. In a column-major layout, the first dimension varies fastest in memory, which is the default layout in languages like Fortran and MATLAB. 
Example: ```mojo from layout import Layout # Create a 3x4 column-major layout var layout = Layout.col_major(3, 4) # Result: Layout with shape (3,4) and stride (1,3) ``` **Args:** * ​\*dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Variable number of dimension sizes. **Returns:** `Self`: A column-major Layout with the specified dimensions `static col_major(shape: IntTuple) -> Self` Creates a column-major layout with the specified shape. In a column-major layout, the first dimension varies fastest in memory, which is the default layout in languages like Fortran and MATLAB. Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Create a 3x4 column-major layout var layout = Layout.col_major(IntTuple(3, 4)) # Result: Layout with shape (3,4) and stride (1,3) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): An IntTuple specifying the dimensions. **Returns:** `Self`: A column-major Layout with the specified shape `static col_major[rank: Int](dims: DimList) -> Self` Creates a col-major layout from a DimList with compile-time rank. This method creates a col-major layout where the first dimension varies fastest in memory. It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. Example: ```mojo from layout import Layout from layout.layout import DimList # Create a col-major layout with compile-time rank var dims = DimList(3, 4) var layout = Layout.col_major[2](dims) # Result: Layout with shape (3,4) and stride (1,3) ``` **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​dims ([`DimList`](/mojo/kernels/buffer/dimlist/DimList)): A DimList containing the dimensions of the layout. **Returns:** `Self`: A col-major Layout with the specified dimensions and computed strides. 
`static col_major[rank: Int](tuple: IndexList[rank]) -> Self` Creates a col-major layout from an IndexList with compile-time rank. This method creates a col-major layout where the first dimension varies fastest in memory. It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. Example: ```mojo from layout import Layout from utils import IndexList # Create a col-major layout with compile-time rank var idx_list = IndexList[2](3, 4) var layout = Layout.col_major[2](idx_list) # Result: Layout with shape (3,4) and stride (1,3) ``` **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​tuple ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): An IndexList containing the dimensions of the layout. **Returns:** `Self`: A col-major Layout with the specified dimensions and computed strides. ### `row_major` `static row_major(*dims: Int) -> Self` Creates a row-major layout with the specified dimensions. In a row-major layout, the last dimension varies fastest in memory, which is the default layout in languages like C, C++, and Python. Example: ```mojo from layout import Layout # Create a 3x4 row-major layout var layout = Layout.row_major(3, 4) # Result: Layout with shape (3,4) and stride (4,1) ``` **Args:** * ​\*dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Variable number of dimension sizes. **Returns:** `Self`: A row-major Layout with the specified dimensions `static row_major[rank: Int](dims: DimList) -> Self` Creates a row-major layout from a DimList with compile-time rank. This method creates a row-major layout where the last dimension varies fastest in memory. It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. 
Example: ```mojo from layout import Layout from layout.layout import DimList # Create a row-major layout with compile-time rank var dims = DimList(3, 4) var layout = Layout.row_major[2](dims) # Result: Layout with shape (3,4) and stride (4,1) ``` **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​dims ([`DimList`](/mojo/kernels/buffer/dimlist/DimList)): A DimList containing the dimensions of the layout. **Returns:** `Self`: A row-major Layout with the specified dimensions and computed strides. `static row_major[rank: Int](tuple: IndexList[rank]) -> Self` Creates a row-major layout from an IndexList with compile-time rank. This method creates a row-major layout where the last dimension varies fastest in memory. It handles both known and unknown dimensions at compile time, properly calculating strides for each dimension. If any dimension is unknown, subsequent strides will also be marked as unknown. Example: ```mojo from layout import Layout from utils import IndexList # Create a row-major layout with compile-time rank var idx_list = IndexList[2](3, 4) var layout = Layout.row_major[2](idx_list) # Result: Layout with shape (3,4) and stride (4,1) ``` **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. **Args:** * ​tuple ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): An IndexList containing the dimensions of the layout. **Returns:** `Self`: A row-major Layout with the specified dimensions and computed strides. `static row_major[rank: Int]() -> Self` Creates a row-major layout of the given compile-time rank with unknown values for each axis. Example: ```mojo from layout import Layout var layout = Layout.row_major[2]() # Result: Layout with shape (UNKNOWN_VALUE, UNKNOWN_VALUE) ``` **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The compile-time rank (number of dimensions) of the layout. 
**Returns:** `Self`: A row-major Layout with the given rank. `static row_major(shape: IntTuple) -> Self` Creates a row-major layout from an IntTuple of dimensions. In a row-major layout, the last dimension varies fastest in memory. This method computes the appropriate strides for a row-major layout given the input shape. Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Create a row-major layout from a shape tuple var shape = IntTuple(3, 4) var layout = Layout.row_major(shape) # Result: Layout with shape (3,4) and stride (4,1) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): An IntTuple containing the dimensions of the layout. **Returns:** `Self`: A row-major Layout with the specified shape and computed strides. ### `make_shape_unknown` `make_shape_unknown[axis: Int = -1](self) -> Self` Creates a new Layout with unknown shape dimensions. This method creates a copy of the current Layout but marks either all dimensions or a specific dimension as unknown, while preserving the original strides. This is useful for tiling tensors with runtime sizes where the tile's shape is unknown but the memory layout (strides) remains constant. Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Mark all dimensions as unknown var layout = Layout(IntTuple(2, 3)) var unknown = layout.make_shape_unknown() # Result: Layout with shape (?, ?) and original strides # Mark only first dimension as unknown var partial = layout.make_shape_unknown[0]() # Result: Layout with shape (?, 3) and original strides ``` **Parameters:** * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension to mark as unknown. If UNKNOWN\_VALUE (default), all dimensions are marked as unknown. **Returns:** `Self`: A new Layout with the specified dimension(s) marked as unknown and original strides preserved. ### `__str__` `__str__(self) -> String` Converts the layout to a string representation. 
**Returns:** `String`: A string representation of the layout in the format "(shape:stride)". ### `write_to` `write_to(self, mut writer: T)` Writes the layout to the specified writer. Formats the layout as "(shape:stride)" and writes it to the provided writer. **Args:** * ​writer (`T`): The writer to output the layout representation to. ### `__len__` `__len__(self) -> Int` Returns the number of dimensions in the layout. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the shape tuple. ### `__iter__` `__iter__(ref self) -> _LayoutIter[origin_of((muttoimm self_is_origin))]` Returns an iterator over the layout's dimensions. Each iteration yields a Layout containing the shape and stride for one dimension. **Returns:** `_LayoutIter`: An iterator over the layout's dimensions. ### `size` `size(self) -> Int` Returns the total number of elements in the layout's domain. Calculates the product of all dimensions in the shape. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total number of elements in the layout. ### `cosize` `cosize(self) -> Int` Returns the size of the memory region spanned by the layout. Calculates the maximum memory index plus one, representing the total memory footprint required by the layout. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the memory region required by the layout. ### `rank` `rank(self) -> Int` Returns the number of dimensions in the layout. This is equivalent to `__len__` and returns the number of elements in the shape tuple. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of dimensions in the layout. ### `__call__` `__call__(self, idx: IntTuple) -> Int` Maps logical coordinates to a linear memory index. This is the core functionality of a layout, converting multi-dimensional coordinates to a linear memory location. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The logical coordinates to map. 
**Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The linear memory index corresponding to the given coordinates. ### `append` `append(mut self, item: Self)` Appends another layout to this layout. This method adds the shape and stride from the provided layout to this layout, effectively increasing its dimensionality. **Args:** * ​item (`Self`): The layout to append to this layout. ### `all_dims_known` `all_dims_known(self) -> Bool` Checks if all dimensions in the layout have known values. A dimension is considered unknown if its shape or stride is set to the special `UNKNOWN_VALUE` constant. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all dimensions have known shape and stride values, False otherwise. ### `known_shape` `known_shape(self) -> Bool` Checks if all shape dimensions in the layout have known values. A dimension is considered unknown if its shape is set to the special `UNKNOWN_VALUE` constant. This method only checks shapes, not strides. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all shape dimensions have known values, False otherwise. ### `transpose` `transpose(self) -> Self` Transposes the layout by reversing the order of dimensions. For an n-dimensional layout, this reverses the order of both shapes and strides. For nested layouts, only the top-level dimensions are transposed, not the hierarchical structure within nested tuples. 
Example: ```mojo from layout import Layout from layout.int_tuple import IntTuple # Simple 2D transpose (row-major to column-major) var layout = Layout.row_major(3, 4) # shape (3,4), stride (4,1) var transposed = layout.transpose() # shape (4,3), stride (1,4) # 3D transpose var layout3d = Layout.row_major(2, 3, 4) # shape (2,3,4), stride (12,4,1) var trans3d = layout3d.transpose() # shape (4,3,2), stride (1,4,12) # Nested layout - only top level transposed var nested = Layout( IntTuple(IntTuple(2, 3), 4), IntTuple(IntTuple(12, 4), 1) ) var trans_nested = nested.transpose() # Result: shape (4, (2,3)), stride (1, (12,4)) ``` **Returns:** `Self`: A new Layout with transposed dimensions.
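Pulling the core methods above together, here is a hedged round-trip sketch of `__call__`, `idx2crd`, `size`, and `cosize` on a dense row-major layout (wrapping the scalar index in an `IntTuple` is an assumption based on the `idx2crd` signature):

```mojo
from layout import Layout
from layout.int_tuple import IntTuple

var layout = Layout.row_major(3, 4)  # shape (3,4), stride (4,1)

# Map logical coordinates to a linear index: 1*4 + 2*1 = 6.
var idx = layout(IntTuple(1, 2))

# idx2crd is the inverse mapping, recovering the coordinates (1, 2).
var coords = layout.idx2crd(IntTuple(6))

# size() is the product of the shape; for a dense layout cosize()
# equals size() because there are no gaps in memory.
print(layout.size())    # 12
print(layout.cosize())  # 12
```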
--- ## LayoutTrait
Defines the interface for mapping between logical coordinates and memory indices. The `LayoutTrait` provides a common interface for all layout types, including basic layouts, swizzles, and composed layouts. It enables mapping from multi-dimensional logical coordinates to linear memory indices, which is essential for tensor operations. Implementations of this trait must provide methods for: 1. Mapping coordinates to indices via the `__call__` method 2. Calculating the total size of the layout's domain 3. Calculating the size of the layout's codomain (memory footprint) 4. Indicating whether the layout has a valid shape This trait serves as the foundation for the layout system, allowing different layout implementations to be used interchangeably in algorithms. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. 
### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `has_shape` `comptime has_shape` Indicates whether the layout has a valid shape. Layouts and ComposedLayouts with at least one Layout have valid shapes and can be used in layout algebra. Swizzles don't have shapes and should be excluded from layout algebra. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `__call__` `__call__(self: _Self, index: IntTuple) -> Int` Maps a logical coordinate to a linear memory index. **Args:** * ​index ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): An IntTuple representing the logical coordinates to map. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The linear memory index corresponding to the given coordinates. ### `size` `size(self: _Self) -> Int` Returns the total number of elements in the layout's domain. For a layout with shape (m, n), this returns m \* n, representing the total number of valid coordinates in the layout. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total number of elements in the layout. ### `cosize` `cosize(self: _Self) -> Int` Returns the size of the memory region spanned by the layout. 
For a layout with shape `(m, n)` and stride `(r, s)`, this returns `(m-1)*r + (n-1)*s + 1`, representing the memory footprint. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the memory region required by the layout. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
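To make the cosize formula concrete, the following sketch evaluates it for a dense and a padded layout using the `Layout` struct documented above: with shape `(2, 3)`, a dense stride `(3, 1)` gives `(2-1)*3 + (3-1)*1 + 1 = 6`, while a padded stride `(4, 1)` gives `(2-1)*4 + (3-1)*1 + 1 = 7`, exceeding `size() == 6` because of the padding gap.

```mojo
from layout import Layout
from layout.int_tuple import IntTuple

# Dense layout: footprint equals the number of elements.
var dense = Layout(IntTuple(2, 3), IntTuple(3, 1))
print(dense.cosize())   # 6

# Padded layout: each row starts 4 elements apart, leaving a gap.
var padded = Layout(IntTuple(2, 3), IntTuple(4, 1))
print(padded.cosize())  # 7
```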
--- ## MakeLayoutList
`MakeLayoutList(v0: Layout, v1: Layout) -> LayoutList` Creates a list containing two layouts. This is a convenience function for creating a LayoutList with two elements. **Args:** * ​v0 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout to include in the list. * ​v1 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout to include in the list. **Returns:** `LayoutList`: A LayoutList containing the two provided layouts.
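A minimal usage sketch (the tile shapes chosen here are illustrative):

```mojo
from layout import Layout
from layout.layout import MakeLayoutList

# Bundle two tile layouts into a LayoutList, e.g. for tiling operations.
var tiles = MakeLayoutList(Layout.row_major(2, 2), Layout.row_major(4, 4))
print(len(tiles))  # 2
```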
--- ## MakeTileLayoutList
`MakeTileLayoutList[*tile_sizes: Int]() -> LayoutList` Creates a list of layouts for tiling operations. This function creates a list of simple layouts, each with a shape from the provided tile\_sizes and a stride of 1. These layouts can be used for tiling operations. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/stdlib/builtin/int/Int)): Variable number of integer tile dimensions. **Returns:** `LayoutList`: A LayoutList containing layouts for each tile size.
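A minimal usage sketch (tile sizes 4 and 8 are illustrative; per the description, each resulting layout has the given size as its shape and a stride of 1):

```mojo
from layout.layout import MakeTileLayoutList

# One simple layout per tile size, each with stride 1.
var tiles = MakeTileLayoutList[4, 8]()
print(len(tiles))  # 2
```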
--- ## apply_tiler
`apply_tiler[func: fn(Layout, Layout) -> Layout](layout_a: Layout, tiler: List[Layout]) -> Layout` Applies a layout transformation function to each element of a layout with a tiler. This utility function applies the specified transformation function to each corresponding pair of elements from the layout and tiler list. It's a generic mechanism for implementing various tiling operations. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import apply_tiler, logical_divide # Apply logical_divide to each element of a layout with a tiler var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2), IntTuple(1, 2))) var result = apply_tiler[logical_divide](base, tilers) ``` **Parameters:** * ​func (`fn(Layout, Layout) -> Layout`): A function that takes two layouts and returns a transformed layout. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to transform. * ​tiler ([`List`](/mojo/stdlib/collections/list/List)): A list of layouts to use in the transformation. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout resulting from applying the transformation function to each pair.
--- ## blocked_product
`blocked_product(layout_a: Layout, layout_b: Layout, coalesce_output: Bool = False) -> Layout` Creates a blocked layout by combining two layouts. This function creates a hierarchical blocked layout by combining a base layout with a block layout. The result is a layout where each element of the base layout is replaced by a block defined by the second layout. This is particularly useful for creating tiled layouts for efficient cache utilization in tensor operations like matrix multiplication. Example: ```mojo from layout import Layout from layout.layout import blocked_product # Create a 2x3 matrix layout var matrix = Layout.row_major(2, 3) # Define 2x2 blocks var block = Layout.row_major(2, 2) # Create a blocked layout with 2x2 blocks var blocked = blocked_product(block, matrix) ``` Output: ```plaintext (((2, 2), (2, 3)):((2, 12), (1, 4))) 0 1 2 3 4 5 +----+----+----+----+----+----+ 0 | 0 | 1 | 4 | 5 | 8 | 9 | +----+----+----+----+----+----+ 1 | 2 | 3 | 6 | 7 | 10 | 11 | +----+----+----+----+----+----+ 2 | 12 | 13 | 16 | 17 | 20 | 21 | +----+----+----+----+----+----+ 3 | 14 | 15 | 18 | 19 | 22 | 23 | +----+----+----+----+----+----+ ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to be blocked. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The block layout defining the structure within each block. * ​coalesce\_output ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to coalesce the output layout. Default is False. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the blocked structure
--- ## coalesce
`coalesce(layout: Layout, keep_rank: Bool = False) -> Layout` Simplifies a layout by combining dimensions with contiguous strides. This function reduces the rank of a layout by merging dimensions that have contiguous memory layouts, resulting in a simpler but equivalent layout. Example: ```mojo from layout import Layout, IntTuple from layout.layout import coalesce # A layout with shape (2, (1, 4)) and stride (1, (4, 2)) can be coalesced var layout = Layout(IntTuple(2, IntTuple(1, 4)), IntTuple(1, IntTuple(4, 2))) var coalesced = coalesce(layout) # Result: Layout with shape (8) and stride (1) ``` **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to coalesce. * ​keep\_rank ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, maintains the original rank of the layout. Default is False. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A simplified layout with reduced rank where possible.
--- ## complement
`complement(layout: Layout, size: Int = 1) -> Layout` Computes the complement layout for a given layout. This function creates a layout that represents the "gaps" or complementary structure of the input layout. It's useful for creating hierarchical layouts where you need to fill in the spaces between existing layout elements. Example: ```mojo from layout import Layout, IntTuple from layout.layout import complement # Compute the complement of a layout var base = Layout(IntTuple(2, 3), IntTuple(3, 1)) var comp = complement(base, 10) # Result: A layout that fills the gaps in the original layout ``` **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The input layout to compute the complement for. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The total size of the memory region to consider. Defaults to 1. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the complement of the input layout.
--- ## composition
`composition(layout_a: Layout, layout_b: Layout) -> Layout` Composes two layouts to create a new layout. This function creates a new layout by composing two layouts, where the first layout defines the outer structure and the second layout defines the inner structure. The new layout is compatible with `layout_b` (that is, it has the same `size` and every set of coordinates in `layout_b` has an equivalent in the new layout). You can think of `layout_b` as selecting a subset of elements from `layout_a`. Example: ```mojo from layout.layout import Layout, IntTuple from layout.layout import composition # Compose a row-major layout with a tiling layout var base = Layout.row_major(6, 8) var tiling = Layout(IntTuple(3, 2), IntTuple(1, 3)) var composed = composition(base, tiling) # Result: A layout that represents a 3x2 tile from # layout_a ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The outer layout. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The inner layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the composition of the two layouts. `composition(layout_a: Layout, tiler: List[Layout]) -> Layout` Composes a layout with a list of layouts to create a hierarchical layout. This function creates a new layout by composing each element of the first layout with the corresponding element in the tiler list. If the tiler list is shorter than the layout, the remaining elements from the layout are appended unchanged. 
Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import composition # Compose a layout with a list of tiling layouts var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2), IntTuple(1, 2))) tilers.append(Layout(IntTuple(3, 3), IntTuple(1, 3))) var composed = composition(base, tilers) # Result: A layout with hierarchical tiling based on the tiler list ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to compose with the tiler. * ​tiler ([`List`](/mojo/stdlib/collections/list/List)): A list of layouts to compose with the base layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the composition of the base layout with the tiler.
--- ## cosize
`cosize(l: Layout) -> Int` Returns the size of the memory region spanned by the layout. This is a standalone function equivalent to the Layout.cosize() method. **Args:** * ​l ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to calculate the cosize for. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the memory region required by the layout.
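A minimal sketch showing that the free function agrees with the method form:

```mojo
from layout import Layout
from layout.layout import cosize

var layout = Layout.row_major(3, 4)
print(cosize(layout))  # 12, same as layout.cosize()
```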
--- ## downcast
`downcast(layout: Layout, factor: Int) -> Layout` Splits each element in a layout into `factor` finer elements, producing a finer-grained layout that covers the same memory region, so that alignment is preserved. This function is useful for converting between different data type granularities, such as from uint128 to bf16. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to downcast. * ​factor ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of finer elements each original element is split into. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout with adjusted shape and stride for the finer granularity.
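An illustrative sketch (the factor of 8 matches the uint128-to-bf16 case mentioned above; the exact resulting shape and stride depend on the input layout, so no output is shown):

```mojo
from layout import Layout
from layout.layout import downcast

# View each element of a 4x4 layout as 8 finer elements, e.g.
# reinterpreting a uint128 buffer as bf16 while preserving alignment.
var coarse = Layout.row_major(4, 4)
var fine = downcast(coarse, 8)
print(fine)
```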
--- ## expand_modes_alike
`expand_modes_alike(shape_a: IntTuple, stride_a: IntTuple, shape_b: IntTuple, stride_b: IntTuple) -> InlineArray[IntTuple, 3]` Aligns two shape-stride pairs to have the same hierarchical structure. This function is used to make two layouts compatible for operations by ensuring they have the same hierarchical structure, expanding scalar values into tuples as needed. **Args:** * ​shape\_a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first shape tuple. * ​stride\_a ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The first stride tuple. * ​shape\_b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second shape tuple. * ​stride\_b ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The second stride tuple. **Returns:** [`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray): An array containing three tuples: the common shape, the expanded stride\_a, and the expanded stride\_b. `expand_modes_alike(layout_a: Layout, layout_b: Layout) -> InlineArray[Layout, 2]` Aligns two layouts to have the same hierarchical structure. This function tiles both layouts so they mirror each other's structure, making them compatible for operations that require matching hierarchies. 
Example: Given layouts with different structures: * layout\_0: (((3, (5, 2)), 4):((1, (24, 12)), 3)) * layout\_1: ((30, (2, 2)):(2, (60, 1))) The result would be two layouts with matching structures: * (((3, (5, 2)), (2, 2)):((1, (24, 12)), (3, 6))) * (((3, (5, 2)), (2, 2)):((2, (6, 30)), (60, 1))) ```mojo from layout import Layout, IntTuple from layout.layout import expand_modes_alike comptime layout_0 = Layout( IntTuple(IntTuple(3, IntTuple(5, 2)), 4), IntTuple(IntTuple(1, IntTuple(24, 12)), 3), ) comptime layout_1 = Layout( IntTuple(30, IntTuple(2, 2)), IntTuple(2, IntTuple(60, 1)) ) comptime uc = expand_modes_alike(layout_0, layout_1) print(uc[0]) # (((3, (5, 2)), (2, 2)):((1, (24, 12)), (3, 6))) print(uc[1]) # (((3, (5, 2)), (2, 2)):((2, (6, 30)), (60, 1))) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout to align. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout to align. **Returns:** [`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray): An array containing two layouts with matching hierarchical structures.
--- ## expand_strides
`expand_strides(shape: IntTuple, stride: Int) -> IntTuple` Expands a scalar stride into a stride tuple matching a shape tuple. This function creates a stride tuple that matches the structure of a shape tuple, with each stride value calculated based on the cumulative product of shape dimensions. **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape tuple to match. * ​stride ([`Int`](/mojo/stdlib/builtin/int/Int)): The base stride value to expand. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A stride tuple matching the structure of the shape tuple.
--- ## format_layout
`format_layout[W: Writer](layout: Layout, mut writer: W)` Formats a 2D layout as a table and writes it to the specified writer. This function creates a visual representation of a 2D layout as a table showing the memory indices for each logical coordinate. **Parameters:** * ​W ([`Writer`](/mojo/stdlib/io/write/Writer)): Type parameter representing a Writer implementation. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The 2D layout to format. * ​writer (`W`): The writer to output the formatted layout to.
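A minimal sketch that renders the index table into a `String` (assuming `String` satisfies the `Writer` trait, as it does in the Mojo standard library):

```mojo
from layout import Layout
from layout.layout import format_layout

# Render the memory-index table for a 2x3 row-major layout.
var layout = Layout.row_major(2, 3)
var out = String()
format_layout(layout, out)
print(out)
```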
--- ## hierarchical_unzip
`hierarchical_unzip(layout_a: Layout, tiler: List[Layout]) -> Layout` Hierarchically unzips a layout according to a list of layouts. This function creates a hierarchical layout by unzipping the first layout according to the layouts in the tiler list. It's useful for decomposing a layout into hierarchical components for more efficient memory access patterns or to enable specialized tensor operations. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import hierarchical_unzip # Create a layout to unzip var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2))) var result = hierarchical_unzip(base, tilers) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be unzipped. * ​tiler ([`List`](/mojo/stdlib/collections/list/List)): A list of layouts defining the unzipping patterns. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical unzipping with components from both the original layout and the tiler layouts. `hierarchical_unzip(layout_a: Layout, layout_b: Layout) -> Layout` Hierarchically unzips a layout according to another layout. This function creates a hierarchical layout by unzipping the first layout according to the second layout. It's a fundamental operation for decomposing a layout into hierarchical components, which enables more efficient memory access patterns for various tensor operations. Example: ```mojo from layout import Layout, IntTuple from layout.layout import hierarchical_unzip # Create layouts var base = Layout.row_major(6, 8) var pattern = Layout(IntTuple(2, 2)) var result = hierarchical_unzip(base, pattern) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be unzipped. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout defining the unzipping pattern. 
**Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical unzipping of layout\_a according to the pattern defined by layout\_b.
--- ## layout (3)
Provides a high-performance tensor layout system for memory mapping and indexing. The layout module implements a comprehensive system for describing memory layouts of multi-dimensional tensors, enabling efficient mapping between logical tensor coordinates and physical memory locations. This is a critical component for high-performance tensor operations in machine learning and scientific computing. These low-level primitives require careful use to avoid errors. Understanding the relationship between tensor shapes, strides, and memory layout is essential for effective use. Key components: * `LayoutTrait`: Core trait defining the interface for all layout types * `Layout`: Primary struct implementing memory layout with shape and stride information * Layout algebra: Functions for composing, dividing, and transforming layouts * Tiling operations: Functions for hierarchical decomposition of layouts Performance features: * Zero-cost abstractions for mapping between logical and physical indices * Support for both compile-time and runtime-determined shapes * Efficient memory access patterns through layout transformations * Hierarchical tiling for cache-friendly memory access Common use cases: * Defining memory layouts for tensors with different storage formats (row-major, column-major) * Implementing efficient tensor operations with optimal memory access patterns * Supporting hardware-specific memory layouts for accelerators * Enabling zero-copy tensor views and reshaping operations Example: ```mojo from layout import Layout, IntTuple from layout.layout import blocked_product # Create a 3x4 row-major layout var layout = Layout.row_major(3, 4) # Access the memory location for logical coordinates (1, 2) var memory_idx = layout(IntTuple(1, 2)) # Create a tiled layout for blocked matrix multiplication var tiled = blocked_product(layout, Layout(IntTuple(2, 2))) ``` ## `comptime` values ### `LayoutList` `comptime LayoutList = List[Layout]` Type alias for a list of Layout objects. 
## Structs * [​`Layout`](./Layout): Represents a memory layout for multi-dimensional data. ## Traits * [​`LayoutTrait`](./LayoutTrait): Defines the interface for mapping between logical coordinates and memory indices. ## Functions * [​`apply_tiler`](./apply_tiler): Applies a layout transformation function to each element of a layout with a tiler. * [​`blocked_product`](./blocked_product): Creates a blocked layout by combining two layouts. * [​`coalesce`](./coalesce): Simplifies a layout by combining dimensions with contiguous strides. * [​`complement`](./complement): Computes the complement layout for a given layout. * [​`composition`](./composition): Composes two layouts to create a new layout. * [​`cosize`](./cosize): Returns the size of the memory region spanned by the layout. * [​`downcast`](./downcast): Splits elements in a layout to create a finer layout without changing the total number of elements so that the alignment is preserved. * [​`expand_modes_alike`](./expand_modes_alike): Aligns two shape-stride pairs to have the same hierarchical structure. * [​`expand_strides`](./expand_strides): Expands a scalar stride into a stride tuple matching a shape tuple. * [​`format_layout`](./format_layout): Formats a 2D layout as a table and writes it to the specified writer. * [​`hierarchical_unzip`](./hierarchical_unzip): Hierarchically unzips a layout according to a list of layouts. * [​`is_contiguous_dim`](./is_contiguous_dim): Checks if a flat layout is contiguous in a specific dimension. * [​`is_row_major`](./is_row_major): Checks if a layout has row-major ordering for the specified rank. * [​`logical_divide`](./logical_divide): Divides a layout into blocks according to another layout. * [​`logical_product`](./logical_product): Creates a product of two layouts. * [​`make_layout`](./make_layout): Creates a composite layout by concatenating multiple layouts. 
* [​`make_ordered_layout`](./make_ordered_layout): Creates a layout with strides ordered according to a specified traversal order. * [​`MakeLayoutList`](./MakeLayoutList): Creates a list containing two layouts. * [​`MakeTileLayoutList`](./MakeTileLayoutList): Creates a list of layouts for tiling operations. * [​`print_layout`](./print_layout): Prints a 2D layout to the standard output. * [​`right_inverse`](./right_inverse): Creates a right inverse of a layout. * [​`size`](./size): Returns the total number of elements in the layout's domain. * [​`sublayout`](./sublayout): Creates a sublayout by selecting specific dimensions from a layout. * [​`tile_to_shape`](./tile_to_shape): Creates a layout by tiling a base layout to match a target shape. * [​`upcast`](./upcast): Fuses consecutive elements in a layout to create a coarser layout. * [​`zip_modes`](./zip_modes): Combines corresponding modes from two layouts. * [​`zipped_divide`](./zipped_divide): Divides a layout into blocks according to another layout.
--- ## is_contiguous_dim
`is_contiguous_dim(layout: Layout, dim: Int) -> Bool` Checks if a flat layout is contiguous in a specific dimension. This function checks if a flat layout is contiguous in a specified dimension, considering both positive strides and zero strides with a single element. The latter case is necessary for coalesced layouts. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to check. * ​dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the layout is contiguous in the specified dimension, False otherwise.
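As a minimal sketch (untested, following the constructors documented elsewhere in this reference):

```mojo
from layout import Layout
from layout.layout import is_contiguous_dim

# Row-major 4x8 layout: strides (8, 1).
var layout = Layout.row_major(4, 8)
# Dimension 1 has stride 1: the contiguous, fastest-varying dimension.
print(is_contiguous_dim(layout, 1))
```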
--- ## is_row_major
`is_row_major[rank: Int](layout: Layout) -> Bool` Checks if a layout has row-major ordering for the specified rank. A row-major layout has strides that decrease from left to right, with the rightmost dimension having a stride of 1. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The expected rank of the layout. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the layout has row-major ordering for the specified rank, False otherwise.
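A minimal sketch using the explicit shape/stride constructor shown elsewhere in this reference (the expected results follow from the definition above; treat them as illustrative):

```mojo
from layout import Layout, IntTuple
from layout.layout import is_row_major

# Strides (4, 1): decreasing left to right, rightmost is 1.
var rm = Layout(IntTuple(3, 4), IntTuple(4, 1))
# Column-major strides (1, 3) do not satisfy the row-major test.
var cm = Layout(IntTuple(3, 4), IntTuple(1, 3))
print(is_row_major[2](rm))  # True per the definition above
print(is_row_major[2](cm))  # False per the definition above
```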
--- ## logical_divide
`logical_divide(layout_a: Layout, _layout_b: Layout) -> Layout` Divides a layout into blocks according to another layout. This function creates a hierarchical layout by dividing the first layout according to the second layout. It's useful for creating blocked or tiled representations of tensors. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​\_layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout defining the division pattern. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division. `logical_divide(layout_a: Layout, tiler: List[Layout]) -> Layout` Divides a layout into blocks according to a list of layouts. This is a variant of logical\_divide that works with a list of layouts for more complex tiling patterns. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​tiler ([`List`](/mojo/stdlib/collections/list/List)): A list of layouts defining the division patterns. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division.
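A minimal sketch, mirroring the tiling examples used elsewhere in this reference:

```mojo
from layout import Layout, IntTuple
from layout.layout import logical_divide

var base = Layout.row_major(6, 8)
var pattern = Layout(IntTuple(2, 2))
# Divide the 6x8 layout into 2x2 blocks, producing a hierarchical layout.
var blocked = logical_divide(base, pattern)
```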
--- ## logical_product
`logical_product(_layout_a: Layout, layout_b: Layout) -> Layout` Creates a product of two layouts. This function creates a hierarchical layout by taking the logical product of two layouts. It's a fundamental operation for creating blocked or tiled layouts. **Args:** * ​\_layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the logical product of the two layouts. `logical_product(layout_a: Layout, tiler: List[Layout]) -> Layout` Creates a product of a layout with a list of layouts. This is a variant of logical\_product that works with a list of layouts for more complex tiling patterns. It applies the logical\_product operation to each element of the layout with the corresponding element in the tiler list. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import logical_product # Create a product of a layout with a list of layouts var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2))) var result = logical_product(base, tilers) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to create products with. * ​tiler ([`List`](/mojo/stdlib/collections/list/List)): A list of layouts defining the product patterns. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the logical product with the tiler layouts.
--- ## make_layout
`make_layout(*layouts: Layout) -> Layout` Creates a composite layout by concatenating multiple layouts. This function combines multiple layouts into a single layout by concatenating their shapes and strides. The resulting layout represents a hierarchical structure where each input layout becomes a component of the output layout. Example: ```mojo from layout import Layout, IntTuple from layout.layout import make_layout var layout1 = Layout(IntTuple(2, 3), IntTuple(3, 1)) var layout2 = Layout(IntTuple(4, 5), IntTuple(5, 1)) var combined = make_layout(layout1, layout2) # Result: Layout with shape ((2, 3), (4, 5)) and stride ((3, 1), (5, 1)) ``` **Args:** * ​\*layouts ([`Layout`](/mojo/kernels/layout/layout/Layout)): Variable number of `Layout` objects to combine. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new Layout with concatenated shapes and strides from the input layouts. `make_layout(layout_a: Layout, layout_b: Layout) -> Layout` Creates a composite layout from two layouts. This is a specialized version of make\_layout that takes exactly two layouts and combines them into a single layout. This function exists as a workaround for compiler limitations. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout to include in the composite. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout to include in the composite. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new `Layout` with concatenated shapes and strides from the input layouts.
--- ## make_ordered_layout
`make_ordered_layout(shape: IntTuple, order: IntTuple) -> Layout` Creates a layout with strides ordered according to a specified traversal order. This function generates a compact (bijective) layout where the stride values follow the traversal order specified by the order parameter. This allows creating layouts with custom memory traversal patterns while maintaining a compact memory representation. Example: ```mojo from layout import IntTuple, Layout from layout.layout import make_ordered_layout # Create a layout with shape (2,3,4,5) where dimensions are traversed # in the order: dim0, dim3, dim2, dim1 var layout = make_ordered_layout( IntTuple(2, 3, 4, 5), IntTuple(1, 4, 3, 2) ) # Result: Layout with shape (2,3,4,5) and stride (1,40,10,2) ``` **Args:** * ​shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of the layout. * ​order ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The traversal order priority (lower values indicate higher priority). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A `Layout` with the specified shape and strides ordered according to the traversal order.
--- ## print_layout
`print_layout(layout: Layout)` Prints a 2D layout to the standard output. This function visualizes a 2D layout by printing a formatted table showing the memory indices for each logical coordinate. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The 2D layout to print.
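A minimal sketch (this is the console-output counterpart of `format_layout`):

```mojo
from layout import Layout
from layout.layout import print_layout

# Print the memory-index table of a 2x4 row-major layout to stdout.
print_layout(Layout.row_major(2, 4))
```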
--- ## right_inverse
`right_inverse(layout: Layout) -> Layout` Creates a right inverse of a layout. The right inverse of a layout maps memory indices back to logical coordinates. This is useful for converting between different memory layouts. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to invert. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the right inverse of the input layout.
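A minimal sketch based on the signature above:

```mojo
from layout import Layout
from layout.layout import right_inverse

var layout = Layout.row_major(3, 4)
# inv maps a memory index back to the logical position that produced it.
var inv = right_inverse(layout)
```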
--- ## size (Layout)
`size(l: Layout) -> Int` Returns the total number of elements in the layout's domain. This is a standalone function equivalent to the Layout.size() method. **Args:** * ​l ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to calculate the size for. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total number of elements in the layout.
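A minimal sketch:

```mojo
from layout import Layout
from layout.layout import size

var layout = Layout.row_major(3, 4)
# Total number of elements in the layout's domain: 3 * 4 = 12.
print(size(layout))  # 12
```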
--- ## sublayout
`sublayout(layout: Layout, *modes: Int) -> Layout` Creates a sublayout by selecting specific dimensions from a layout. This function extracts a subset of dimensions from a layout to create a new layout with lower rank. For example, from a 3D layout, you could extract a 2D layout containing only the first and third dimensions. Example: From a layout with shape (3,4,5), sublayout(layout, 0, 2) would create a layout with shape (3,5). **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The source layout to extract dimensions from. * ​\*modes ([`Int`](/mojo/stdlib/builtin/int/Int)): The indices of dimensions to include in the sublayout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout containing only the specified dimensions.
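The textual example above can be sketched in code as:

```mojo
from layout import Layout
from layout.layout import sublayout

var layout = Layout.row_major(3, 4, 5)
# Keep dimensions 0 and 2: the result is a rank-2 layout with shape (3, 5).
var sub = sublayout(layout, 0, 2)
```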
--- ## tile_to_shape
`tile_to_shape(tile: Layout, target_shape: IntTuple, order: Optional[IntTuple] = None) -> Layout` Creates a layout by tiling a base layout to match a target shape. This function creates a hierarchical layout by repeating a tile layout to match a target shape. It calculates how many times the tile needs to be repeated in each dimension to reach the target shape, and creates a tiler layout with this information. Example: ```mojo from layout import Layout, IntTuple from layout.layout import tile_to_shape # Create a 2x2 tile layout var tile = Layout.row_major(2, 2) # Tile it to create a 6x4 layout var tiled = tile_to_shape(tile, IntTuple(6, 4)) # Result: A layout with 3x2 tiles of size 2x2 each ``` **Args:** * ​tile ([`Layout`](/mojo/kernels/layout/layout/Layout)): The base layout to be tiled. * ​target\_shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The desired final shape to tile to. * ​order ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Optional memory ordering for the tiler layout. If None, defaults to column-major ordering. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the tiled structure that matches the target shape.
--- ## upcast
`upcast[check: Bool = True](layout: Layout, factor: Int) -> Layout` Fuses consecutive elements in a layout to create a coarser layout. This function is useful for converting between different data type granularities, such as from bytes to larger data types like bfloat16 or tf32. **Parameters:** * ​check ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to check for incompatible factors. **Args:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to upcast. * ​factor ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of consecutive elements to fuse into one. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout with adjusted shape and stride for the coarser granularity.
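A minimal sketch of the byte-to-wider-type use case described above (the specific shapes are illustrative assumptions):

```mojo
from layout import Layout
from layout.layout import upcast

# A row-major layout over individual bytes.
var byte_layout = Layout.row_major(4, 16)
# View the same memory at 2-byte (e.g. bfloat16) granularity:
# every 2 consecutive byte elements fuse into one coarser element.
var bf16_layout = upcast(byte_layout, 2)
```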
--- ## zip_modes
`zip_modes(layout_a: Layout, layout_b: Layout) -> Layout` Combines corresponding modes from two layouts. This function creates a new layout by combining corresponding dimensions from two layouts. If a dimension in layout\_b has a non-positive shape, the corresponding dimension from layout\_a is used directly. **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The first layout. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The second layout. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout with combined dimensions from both input layouts.
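A minimal sketch based on the signature above:

```mojo
from layout import Layout, IntTuple
from layout.layout import zip_modes

var a = Layout.row_major(6, 8)
var b = Layout(IntTuple(2, 2))
# Pair up corresponding dimensions of the two layouts.
var zipped = zip_modes(a, b)
```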
--- ## zipped_divide
`zipped_divide(layout_a: Layout, layout_b: Layout) -> Layout` Divides a layout into blocks according to another layout. This function creates a hierarchical layout by dividing the first layout according to the second layout. It's an alias for hierarchical\_unzip that provides a more intuitive name for the division operation. This is useful for creating blocked or tiled representations of tensors. Example: ```mojo from layout import Layout, IntTuple from layout.layout import zipped_divide # Create layouts var base = Layout.row_major(6, 8) var pattern = Layout(IntTuple(2, 2)) var result = zipped_divide(base, pattern) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout defining the division pattern. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division of layout\_a according to layout\_b. `zipped_divide(layout_a: Layout, tiler: List[Layout]) -> Layout` Divides a layout into blocks according to a list of layouts. This function creates a hierarchical layout by dividing the first layout according to the layouts in the tiler list. It's an alias for hierarchical\_unzip that provides a more intuitive name for the division operation when working with multiple tiling patterns. Example: ```mojo from layout import Layout, LayoutList, IntTuple from layout.layout import zipped_divide # Create layouts var base = Layout.row_major(6, 8) var tilers = LayoutList() tilers.append(Layout(IntTuple(2, 2))) var result = zipped_divide(base, tilers) ``` **Args:** * ​layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to be divided. * ​tiler ([`List`](/mojo/stdlib/collections/list/List)): A list of layouts defining the division patterns. 
**Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): A new layout representing the hierarchical division of layout\_a according to the patterns in tiler.
--- ## LayoutTensor
`@register_passable(trivial)` `struct LayoutTensor[mut: Bool, //, dtype: DType, layout: Layout, origin: Origin[mut], /, *, address_space: AddressSpace = AddressSpace.GENERIC, element_layout: Layout = Layout(IntTuple(1), IntTuple(1)), layout_int_type: DType = _get_layout_type(layout, address_space), linear_idx_type: DType = _get_index_type(layout, address_space), masked: Bool = False, alignment: Int = align_of[dtype]()]` A high-performance tensor with explicit memory layout and hardware-optimized access patterns. `LayoutTensor` provides a powerful abstraction for multi-dimensional data with precise control over memory organization. It supports various memory layouts (row-major, column-major, tiled), hardware-specific optimizations, and efficient parallel access patterns. Example: ```mojo from layout import Layout, LayoutTensor # Create tensor on CPU using InlineArray to allocate storage space. var storage = InlineArray[Float32, 5 * 4](uninitialized=True) var tensor_5x4 = LayoutTensor[DType.float32, Layout.row_major(5, 4)](storage) ``` ## Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The inferred mutability of the underlying pointer. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the underlying pointer. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of the tensor. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the underlying pointer. * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The address space of the underlying pointer. * ​element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout of each element in the tensor. * ​layout\_int\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type of each dimension of runtime layout. * ​linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type of the index pointing to memory locations. 
* ​masked ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true the tensor is masked and runtime layouts determine the shape. * ​alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Alignment of the data pointer. ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin]`): Pointer to the underlying memory buffer containing the tensor data. This pointer respects the specified address space, alignment, mutability, and origin tracking for memory safety and performance optimization. * ​runtime\_layout (`LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].RuntimeLayoutType`): Runtime representation of the tensor's memory layout. Handles both compile-time and runtime-determined dimensions, enabling efficient mapping between logical tensor coordinates and physical memory locations. * ​runtime\_element\_layout (`LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].RuntimeElementLayoutType`): Runtime representation of each element's internal layout. Used when elements themselves have structure, such as in blocked or tiled layouts. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable), [`_Expable`](/mojo/stdlib/math/math/_Expable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AddressSpaceCastType` `comptime AddressSpaceCastType[address_space: AddressSpace = address_space] = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for address-space-cast result tensors. #### Parameters * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The target address space for the result tensor. ### `BitcastType` `comptime BitcastType[new_dtype: DType, /, address_space: AddressSpace = address_space, element_layout: Layout = element_layout] = LayoutTensor[new_dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for bitcast result tensors. #### Parameters * ​new\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The target data type to cast to. * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The address space for the result tensor. * ​element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The element layout for the result tensor. 
### `CoalesceType` `comptime CoalesceType[element_layout: Layout] = LayoutTensor[dtype, coalesce(layout, False), origin, address_space=address_space, element_layout=element_layout]` Type alias for coalesced result tensors. #### Parameters * ​element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The element layout for the coalesced tensor. ### `CompositionType` `comptime CompositionType[rhs_layout: Layout, dst_layout: Layout = composition(layout, rhs_layout)] = LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for composed layout tensor types. #### Parameters * ​rhs\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to compose with. * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The resulting composed layout. ### `CornerCoordsType` `comptime CornerCoordsType = IndexList[len[IntTuple](flatten(layout.shape)), element_type=layout_int_type]` Index list type for corner coordinates. ### `device_type` `comptime device_type = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` The device-side type representation. ### `DistributeType` `comptime DistributeType[threads_layout: Layout, axis: OptionalReg[Int] = None] = LayoutTensor[dtype, _compute_distribute_layout[layout, threads_layout, axis]()[1], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _distribute_is_masked[layout, threads_layout, axis]() if is_nvidia_gpu() else False]` Type alias for distributed tensor types. #### Parameters * ​threads\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout describing thread distribution. 
* ​axis ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional axis to distribute along. ### `DynamicSplitType` `comptime DynamicSplitType[axis: Int = 0] = LayoutTensor[dtype, layout.make_shape_unknown[axis](), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for dynamic split result tensors. #### Parameters * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to split. ### `element_size` `comptime element_size = element_layout.size()` The number of scalar values in each element of the tensor. ### `element_type` `comptime element_type = SIMD[dtype, LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_size]` The SIMD vector type used for vectorized operations on tensor elements. ### `FlattenedType` `comptime FlattenedType = LayoutTensor[dtype, Layout(IntTuple(-1)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for flattened tensor types. ### `GenericAddressSpaceLayoutTensor` `comptime GenericAddressSpaceLayoutTensor = LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` LayoutTensor variant using generic address space. ### `GenericLayoutTensorType` `comptime GenericLayoutTensorType = LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` LayoutTensor type with generic address space. 
### `idx_list_t` `comptime idx_list_t[rank: Int = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank] = IndexList[rank, element_type=linear_idx_type]` Type alias for index lists of the tensor's rank. #### Parameters * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the index list. ### `MutableAnyType` `comptime MutableAnyType = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Mutable LayoutTensor type with MutAnyOrigin. ### `num_strides` `comptime num_strides = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].RuntimeLayoutType.StrideType.scalar_length` Number of stride values in the layout. ### `OriginCastType` `comptime OriginCastType[mut: Bool, origin: Origin[mut]] = LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for origin-cast result tensors. #### Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the result tensor is mutable. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin for the result tensor. ### `rank` `comptime rank = layout.rank()` The number of dimensions in the tensor's layout. 
### `ReshapeType` `comptime ReshapeType[dst_layout: Layout] = LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for reshaped tensor types. #### Parameters * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped tensor. ### `RuntimeElementLayoutType` `comptime RuntimeElementLayoutType = RuntimeLayout[element_layout, element_type=DType.int32, linear_idx_type=linear_idx_type]` Type alias for the runtime element layout. ### `RuntimeLayoutType` `comptime RuntimeLayoutType = RuntimeLayout[layout, element_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for the runtime layout. ### `ShapeVectorizedType` `comptime ShapeVectorizedType[origin: ImmutOrigin, vector_shape: IntTuple, linear_vectorize: Bool] = LayoutTensor[dtype, coalesce(LayoutTensor._tuple_divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](vector_shape, linear_vectorize)[1], True), origin, address_space=address_space, element_layout=LayoutTensor._tuple_divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](vector_shape, linear_vectorize)[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for shape-vectorized tensor types. #### Parameters * ​origin ([`ImmutOrigin`](/mojo/stdlib/builtin/type_aliases/#immutorigin)): The origin of the result tensor. * ​vector\_shape ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The shape of each vector unit. * ​linear\_vectorize ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to vectorize in a linear manner. 
### `SIMDTileType` `comptime SIMDTileType[tile_size: Int] = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_size, simd_width_of[dtype]()]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_size, simd_width_of[dtype]()](), alignment=alignment]` Type alias for SIMD-sized tile tensors. #### Parameters * ​tile\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the tile along the tiled axis. ### `SIMDVectorizedType` `comptime SIMDVectorizedType = LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, 1, simd_width_of[dtype]()]()[1], True), origin, address_space=address_space, element_layout=LayoutTensor._divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, 1, simd_width_of[dtype]()]()[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Result type for SIMD-width vectorization. ### `SliceType` `comptime SliceType[d0_slice: Slice, d1_slice: Slice] = LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for 2D slice result tensors. #### Parameters * ​d0\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the first dimension. * ​d1\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the second dimension. 
### `SliceType1D` `comptime SliceType1D[d0_slice: Slice, slice_indices: IndexList[1], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 1)] = LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, slice_indices.__getitem__[1, DType.int64, Int](0)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for 1D slice result tensors from higher-rank tensors. #### Parameters * ​d0\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the selected dimension. * ​slice\_indices ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Index of the dimension to slice. * ​\_\_offset\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of fixed dimensions. ### `SliceType2D` `comptime SliceType2D[d0_slice: Slice, d1_slice: Slice, slice_indices: IndexList[2], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)] = LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice, slice_indices.__getitem__[2, DType.int64, Int](0), slice_indices.__getitem__[2, DType.int64, Int](1)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for 2D slice result tensors from higher-rank tensors. 
#### Parameters * ​d0\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the first selected dimension. * ​d1\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the second selected dimension. * ​slice\_indices ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Indices of the two dimensions to slice. * ​\_\_offset\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of fixed dimensions. ### `SplitElementType` `comptime SplitElementType[count: Int, axis: Int = 0] = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, (layout.shape[axis].value() // count), axis]()[0], origin, address_space=address_space, element_layout=element_layout, alignment=alignment]` Type alias for split element tensors. #### Parameters * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of portions to split into. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to split. ### `StackTensorType` `comptime StackTensorType = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` LayoutTensor type for stack-allocated tensors. ### `StaticSplitType` `comptime StaticSplitType[count: Int, axis: Int = 0] = StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, (layout.shape[axis].value() // count), axis]()[0], origin, address_space=address_space, element_layout=element_layout, alignment=alignment], count]` Type alias for static split result tuples. #### Parameters * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of portions to split into. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to split. 
### `storage_size` `comptime storage_size = (size_of[dtype]() * layout.size())` Total storage size in bytes for the tensor data. ### `TiledIteratorType` `comptime TiledIteratorType[*tile_sizes: Int, *, axis: Int = 0] = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes]()]` Type alias for tiled iterator types. #### Parameters * ​\*tile\_sizes ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensions of each tile along each axis. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to iterate. ### `TileType` `comptime TileType[*tile_sizes: Int] = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes](), alignment=alignment]` The tile type returned by the `tile()` method given the specified set of tile sizes. #### Parameters * ​\*tile\_sizes ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. ### `TransposeType` `comptime TransposeType = LayoutTensor[dtype, layout.transpose(), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Result type for transpose operations. 
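To make the relationship between `TileType` and tiling concrete, here is a minimal sketch. It is an illustration only: it assumes a stack-backed tensor constructed from an `InlineArray`, illustrative tile sizes, and the `tile[*tile_sizes](*tile_coords)` call form.

```mojo
from collections import InlineArray
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
comptime layout = Layout.row_major(4, 4)

# Stack-backed storage for the tensor (assumed construction pattern).
var storage = InlineArray[Scalar[dtype], layout.size()](fill=0)
var tensor = LayoutTensor[dtype, layout](storage)

# Extract the 2x2 tile at tile coordinate (1, 0), i.e. rows 2-3,
# columns 0-1. Its static type corresponds to TileType[2, 2].
var t = tensor.tile[2, 2](1, 0)
```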
### `VectorizedType` `comptime VectorizedType[*vector_shape: Int] = LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[1], True), origin, address_space=address_space, element_layout=LayoutTensor._divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for vectorized tensor types. #### Parameters * ​\*vector\_shape ([`Int`](/mojo/stdlib/builtin/int/Int)): The shape of each vector unit along each axis. ## Methods ### `__init__` `__init__(span: Span[Scalar[dtype], origin]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericAddressSpaceLayoutTensor` Create a `LayoutTensor` with a `Span`. **Constraints:** Layout must be fully static. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The `Span` pointing to the underlying data. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(span: Span[Scalar[dtype], origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericAddressSpaceLayoutTensor` Create a `LayoutTensor` with a `Span` and a runtime layout for the tensor. The runtime layout element type will be cast to the layout tensor's layout integer type. **Constraints:** * Element layout must be fully static. 
**Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The `Span` pointing to the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(span: Span[Scalar[dtype], origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericAddressSpaceLayoutTensor` Create a `LayoutTensor` with a `Span`, a runtime layout of the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor's layout integer type. **Constraints:** * Runtime layout and `LayoutTensor` must have the same bitwidth and index type. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The `Span` pointing to the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin]) -> Self` Create a `LayoutTensor` with an `UnsafePointer`. **Constraints:** Layout must be fully static. **Args:** * ​unsafe\_ptr (`LegacyUnsafePointer`): The `UnsafePointer` pointing to the underlying data. 
`__init__(unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> Self` Create a `LayoutTensor` with an `UnsafePointer` and a runtime layout for the tensor. The runtime layout element type will be cast to the layout tensor's layout integer type. **Constraints:** Element layout must be fully static. **Args:** * ​unsafe\_ptr (`LegacyUnsafePointer`): The `UnsafePointer` pointing to the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. `__init__(unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> Self` Create a `LayoutTensor` with an `UnsafePointer`, a runtime layout for the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor's layout integer type. **Args:** * ​unsafe\_ptr (`LegacyUnsafePointer`): The `UnsafePointer` pointing to the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. `__init__(ref [origin] device_buffer: DeviceBuffer[dtype]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `DeviceBuffer`. The layout must have statically known dimensions. 
Note that the device buffer memory is on the accelerator device (GPU global memory). Code running on the CPU can use the [`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext) to allocate a `DeviceBuffer` and use that to construct a `LayoutTensor` that can be accessed on the GPU. You cannot directly access data in the `DeviceBuffer` or `LayoutTensor` from the CPU. The following example shows a typical pattern for using `DeviceBuffer` to construct a `LayoutTensor` that you can use on the GPU.

```mojo
from gpu.host import DeviceContext, DeviceBuffer
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
var ctx = DeviceContext()

# Allocate buffers
var dev_buf = ctx.enqueue_create_buffer[dtype](16)
var host_buf = ctx.enqueue_create_host_buffer[dtype](16)

# Ensure buffers have been created
ctx.synchronize()

# Initialize host buffer and copy to device buffer
for i in range(16):
    host_buf[i] = i
ctx.enqueue_copy(dev_buf, host_buf)

# Create LayoutTensor to use on device
comptime layout = Layout.row_major(4, 4)
var tensor = LayoutTensor[dtype, layout](dev_buf)
...
```

**Constraints:** * Layout must be fully static. **Args:** * ​device\_buffer ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Contains the underlying data to point to. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref [origin] host_buffer: HostBuffer[dtype]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `HostBuffer`. The layout must have statically known dimensions. The resulting tensor's data can only be accessed on the CPU. 
```mojo
from gpu.host import DeviceContext, HostBuffer
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
var ctx = DeviceContext()

# Allocate a host buffer large enough for the 4x4 layout (16 elements)
var host_buf = ctx.enqueue_create_host_buffer[dtype](16)

# Ensure the buffer has been created
ctx.synchronize()

comptime layout = Layout.row_major(4, 4)
var tensor = LayoutTensor[dtype, layout](host_buf)
```

**Constraints:** * Layout must be fully static. **Args:** * ​host\_buffer ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Contains the underlying data to point to. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref [origin] device_buffer: DeviceBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `DeviceBuffer` and a runtime layout. The runtime layout element type will be cast to the layout tensor's layout integer type. The resulting tensor's data can only be accessed on the GPU. **Constraints:** * Element layout must be fully static. **Args:** * ​device\_buffer ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): The `DeviceBuffer` containing the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref [origin] host_buffer: HostBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `HostBuffer` and a runtime layout. The runtime layout element type will be cast to the layout tensor's layout integer type. The resulting tensor's data can only be accessed on the CPU. **Constraints:** * Element layout must be fully static. **Args:** * ​host\_buffer ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): The `HostBuffer` containing the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref [origin] device_buffer: DeviceBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `DeviceBuffer`, a runtime layout for the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor's layout integer type. The resulting tensor's data can only be accessed on the GPU. 
**Args:** * ​device\_buffer ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): The `DeviceBuffer` containing the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) `__init__(ref [origin] host_buffer: HostBuffer[dtype], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], element_runtime_layout: RuntimeLayout[element_layout, element_type=element_type, linear_idx_type=linear_idx_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].GenericLayoutTensorType` Create a `LayoutTensor` from a `HostBuffer`, a runtime layout for the tensor, and the runtime layout of each element. The runtime layout element type will be cast to the layout tensor's layout integer type. The resulting tensor's data can only be accessed on the CPU. **Args:** * ​host\_buffer ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): The `HostBuffer` containing the underlying data. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of the `LayoutTensor`. * ​element\_runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The runtime layout of each element. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `__getitem__` `__getitem__[*Tys: Indexer](self, *args: *Tys) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_type` Retrieves a single element from the tensor at the specified indices. This method provides array-like indexing for the tensor. The number of indices provided must match the rank of the tensor; otherwise, an error will occur at runtime. **Parameters:** * ​\*Tys ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the indices. Must implement the `Indexer` trait, and match the rank of the tensor. **Args:** * ​\*args (`*Tys`): The indices specifying the element's position in the tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The element at the specified position with the tensor's data type. `__getitem__(self, crd: RuntimeTuple[S, element_type=element_type]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_type` Retrieves a single element from the tensor at the specified indices. This method provides array-like indexing for the tensor. The number of indices provided must match the rank of the tensor; otherwise, an error will occur at runtime. **Args:** * ​crd ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The coordinate specifying the element's position in each dimension. For example, in a 3D tensor, you would use (i, j, k). **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The element at the specified position with the tensor's data type. 
### `__setitem__` `__setitem__[*Tys: Indexer](self, *args: *Tys, *, val: SIMD[dtype, LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].element_size])` Sets a single element in a tensor at the specified indices. This method provides array-like element assignment for tensors. Notes: * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. **Parameters:** * ​\*Tys ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the indices. Must implement the `Indexer` trait, and match the rank of the tensor. **Args:** * ​\*args (`*Tys`): The indices specifying the element's position in the tensor. * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to write to the tensor at the specified position. ### `__add__` `__add__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Add a scalar value to each element of the tensor. Performs an elementwise addition operation, adding the scalar value to each element in the tensor. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the addition. * For in-place addition, use the `__iadd__` method instead (`+=` operator). **Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to add to each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the addition operation. 
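A brief sketch of element access, assignment, and scalar addition (assuming a stack-backed tensor built from an `InlineArray`; the shape and values are illustrative):

```mojo
from collections import InlineArray
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
comptime layout = Layout.row_major(2, 2)

var storage = InlineArray[Scalar[dtype], layout.size()](fill=0)
var tensor = LayoutTensor[dtype, layout](storage)

# __setitem__: write elements (no bounds checking is performed).
tensor[0, 0] = 1.0
tensor[1, 1] = 2.0

# __getitem__: read an element back.
var x = tensor[1, 1]

# __add__ with a scalar returns a new tensor; `tensor` is unchanged.
var shifted = tensor + 1.0
```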
`__add__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Add another tensor to this tensor elementwise. Performs an elementwise addition between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation creates a copy of the tensor before performing the addition. * For in-place addition, use the `__iadd__` method instead (`+=` operator). **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to add to this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the addition operation. ### `__sub__` `__sub__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Subtract a scalar value from each element of the tensor. Performs an elementwise subtraction operation, subtracting the scalar value from each element in the tensor. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the subtraction. 
* For in-place subtraction, use the `__isub__` method instead (`-=` operator). **Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to subtract from each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the subtraction operation. `__sub__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Subtract another tensor from this tensor elementwise. Performs an elementwise subtraction between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation creates a copy of the tensor before performing the subtraction. * For in-place subtraction, use the `__isub__` method instead (`-=` operator). **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to subtract from this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the subtraction operation. 
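The tensor-tensor elementwise operators can be sketched as follows (assuming two stack-backed tensors of identical shape; values are illustrative):

```mojo
from collections import InlineArray
from layout import Layout, LayoutTensor

comptime dtype = DType.float32
comptime layout = Layout.row_major(2, 2)

var a_storage = InlineArray[Scalar[dtype], layout.size()](fill=2)
var b_storage = InlineArray[Scalar[dtype], layout.size()](fill=1)
var a = LayoutTensor[dtype, layout](a_storage)
var b = LayoutTensor[dtype, layout](b_storage)

# Each operator returns a new tensor; `a` and `b` are unchanged.
var total = a + b   # elementwise sum
var diff = a - b    # elementwise difference
```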
### `__mul__` `__mul__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Multiply each element of the tensor by a scalar value. Performs an elementwise multiplication operation, multiplying each element in the tensor by the scalar value. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the multiplication. * For in-place multiplication, use the `__imul__` method instead (`*=` operator). **Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to multiply with each element. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the multiplication operation. `__mul__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Multiply this tensor with another tensor elementwise. Performs an elementwise multiplication (Hadamard product) between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Note: This is NOT a matrix multiplication operation. For matrix multiplication, use the appropriate matmul function instead. 
Performance: * This operation creates a copy of the tensor before performing the multiplication. * For in-place multiplication, use the `__imul__` method instead (`*=` operator). **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to multiply with this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the elementwise multiplication. ### `__truediv__` `__truediv__(self, other: Scalar[dtype]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Divide each element of the tensor by a scalar value. Performs an elementwise division operation, dividing each element in the tensor by the scalar value. This operation creates a new tensor with the results. Performance: * This operation creates a copy of the tensor before performing the division. * For in-place division, use the `__itruediv__` method instead (`/=` operator). Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to divide each element by. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the division operation. 
`__truediv__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Divide this tensor by another tensor elementwise. Performs an elementwise division between this tensor and another tensor. This operation creates a new tensor with the results. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation creates a copy of the tensor before performing the division. * For in-place division, use the `__itruediv__` method instead (`/=` operator). Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to divide this tensor by. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the results of the division operation. ### `__iadd__` `__iadd__(self, other: Scalar[dtype])` Add a scalar value to each element of the tensor in-place. Performs an elementwise addition operation, adding the scalar value to each element in the tensor. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. 
**Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to add to each element. `__iadd__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Add another tensor to this tensor elementwise in-place. Performs an elementwise addition between this tensor and another tensor. This operation modifies the tensor in-place. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation modifies the tensor directly without creating a copy. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to add to this tensor. ### `__isub__` `__isub__(self, other: Scalar[dtype])` Subtract a scalar value from each element of the tensor in-place. Performs an elementwise subtraction operation, subtracting the scalar value from each element in the tensor. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. **Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to subtract from each element. `__isub__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Subtract another tensor from this tensor elementwise in-place. Performs an elementwise subtraction between this tensor and another tensor. This operation modifies the tensor in-place. 
Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation modifies the tensor directly without creating a copy. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to subtract from this tensor. ### `__imul__` `__imul__(self, other: Scalar[dtype])` Multiply each element of the tensor by a scalar value in-place. Performs an elementwise multiplication operation, multiplying each element in the tensor by the scalar value. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. **Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to multiply with each element. `__imul__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Multiply this tensor with another tensor elementwise in-place. Performs an elementwise multiplication (Hadamard product) between this tensor and another tensor. This operation modifies the tensor in-place. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Note: This is NOT a matrix multiplication operation. For matrix multiplication, use the appropriate matmul function instead. Performance: * This operation modifies the tensor directly without creating a copy. 
**Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to multiply with this tensor. ### `__itruediv__` `__itruediv__(self, other: Scalar[dtype])` Divide each element of the tensor by a scalar value in-place. Performs an elementwise division operation, dividing each element in the tensor by the scalar value. This operation modifies the tensor in-place. Performance: * This operation modifies the tensor directly without creating a copy. Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Args:** * ​other ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to divide each element by. `__itruediv__[other_layout: Layout](self, other: LayoutTensor[dtype, other_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Divide this tensor by another tensor elementwise in-place. Performs an elementwise division between this tensor and another tensor. This operation modifies the tensor in-place. Limited broadcasting is supported: * For tensors of the same rank, shapes must match exactly. * For rank-1 to rank-2 broadcasting, the rank-1 tensor's dimension must match the corresponding dimension of the rank-2 tensor. Performance: * This operation modifies the tensor directly without creating a copy. Notes: * Division by zero will result in undefined behavior or errors depending on the dtype. * For integer dtypes, this performs integer division. **Parameters:** * ​other\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the other tensor. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to divide this tensor by. 
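Taken together, the in-place operators above map onto the usual augmented-assignment syntax. A minimal sketch (construction via `stack_allocation` and the `fill` helper, and the `MutAnyOrigin` spelling, follow other parts of this reference; treat the exact setup as an assumption):

```mojo
from layout import Layout, LayoutTensor

fn inplace_arithmetic_sketch():
    # Fully static 4x4 float32 tensor on the stack (see `stack_allocation`).
    alias layout_4x4 = Layout.row_major(4, 4)
    var t = LayoutTensor[DType.float32, layout_4x4, MutAnyOrigin].stack_allocation()
    _ = t.fill(2.0)

    t += 1.0   # __iadd__:     every element is now 3.0
    t -= 0.5   # __isub__:     2.5
    t *= 2.0   # __imul__:     5.0 (elementwise, NOT matrix multiplication)
    t /= 2.5   # __itruediv__: 2.0 (true division, since the dtype is float)
```

Each operator mutates `t` directly; none of them allocates a new tensor.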
### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). **Returns:** `String`: The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. **Returns:** `String`: The device type's name. ### `__merge_with__` `__merge_with__[other_type: AnyStruct[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]]](self) -> LayoutTensor[dtype, layout, origin_of((mutcast origin._mlir_origin), (mutcast origin._mlir_origin)), address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Returns a tensor merged with the specified `other_type`. **Parameters:** * ​other\_type (`AnyStruct`): The type of the tensor to merge with. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor merged with the specified `other_type`. ### `bitcast` `bitcast[new_dtype: DType, /, target_address_space: AddressSpace = address_space, _element_layout: Layout = element_layout](self) -> LayoutTensor[new_dtype, layout, origin, address_space=target_address_space, element_layout=_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Bitcast the underlying pointer to a new data type. **Parameters:** * ​new\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The new data type it is casting to. * ​target\_address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The address space of the returned `LayoutTensor`. * ​\_element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The element layout of the returned `LayoutTensor`. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new `LayoutTensor` with the same memory location but with the specified data type, address space, and element layout. ### `as_any_origin` `as_any_origin(self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Casts the origin of the mutable `LayoutTensor` to `MutAnyOrigin`. This requires the tensor to already be mutable, as casting mutability is inherently unsafe. It is usually preferred to maintain concrete origin values instead of using `MutAnyOrigin`. However, if it is needed, keep in mind that `MutAnyOrigin` can alias any memory value, so Mojo's ASAP destruction will not apply during the lifetime of the tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor with the origin set to `MutAnyOrigin`. `as_any_origin(self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, ImmutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Casts the origin of the immutable `LayoutTensor` to `ImmutAnyOrigin`. It is usually preferred to maintain concrete origin values instead of using `ImmutAnyOrigin`. However, if it is needed, keep in mind that `ImmutAnyOrigin` can alias any memory value, so Mojo's ASAP destruction will not apply during the lifetime of the tensor. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor with the origin set to `ImmutAnyOrigin`. ### `address_space_cast` `address_space_cast[target_address_space: AddressSpace = address_space](self) -> LayoutTensor[dtype, layout, origin, address_space=target_address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Changes the address space of the `LayoutTensor`. **Parameters:** * ​target\_address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The new address space. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new `LayoutTensor` object with the same type and origin as the original `LayoutTensor`, and the new specified address\_space. ### `get_immutable` `get_immutable(self) -> LayoutTensor[dtype, layout, origin_of((muttoimm origin._mlir_origin)), address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Return an immutable version of this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A `LayoutTensor` covering the same elements, but without mutability. ### `ptr_at_offset` `ptr_at_offset(self, coords: IndexList[size, element_type=element_type]) -> LegacyUnsafePointer[Scalar[dtype], address_space=address_space]` Get a pointer offset at the given flattened coordinates. **Args:** * ​coords ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): A flattened list of the offset coordinates. **Returns:** `LegacyUnsafePointer`: A pointer offset at the given flattened coordinates. ### `__exp__` `__exp__(self) -> Self` Computes element-wise exponential function. Returns a new tensor containing the [element-wise exponential](/mojo/stdlib/math/math/exp/) of the input tensor. **Returns:** `Self`: A new tensor containing the element-wise exponential. 
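A sketch of `__exp__` next to the view helpers above (tensor construction is an assumption based on `stack_allocation` elsewhere in this reference):

```mojo
from layout import Layout, LayoutTensor

fn exp_and_views_sketch():
    alias layout_1x4 = Layout.row_major(1, 4)
    var t = LayoutTensor[DType.float32, layout_1x4, MutAnyOrigin].stack_allocation()
    _ = t.fill(0.0)

    # __exp__ returns a new tensor; exp(0) == 1 for every element.
    var e = t.__exp__()

    # get_immutable: same elements, mutability removed from the type.
    var ro = t.get_immutable()

    # bitcast: reinterpret the same bytes as another dtype (here the raw
    # uint32 bit patterns of the float32 values). No data is copied.
    var bits = t.bitcast[DType.uint32]()
```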
### `load` `load[width: Int](self, m: Int, n: Int) -> SIMD[dtype, width]` Load a SIMD vector from the tensor at the specified 2D coordinates. Performs a vectorized load operation from the tensor's memory, retrieving `width` consecutive elements starting at position (m, n). This method enables efficient SIMD operations on tensor data. Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use `aligned_load` instead when data alignment is guaranteed. * The load operation is optimized based on the tensor's memory layout. Notes: * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. * The elements are loaded according to the tensor's stride configuration. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to load into the SIMD vector. Should match the target hardware's vector width for optimal performance. **Args:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): The row index (first dimension). * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The column index (second dimension). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector containing 'width' consecutive elements from the tensor. `load[width: Int](self, coords: IndexList[size, element_type=element_type]) -> SIMD[dtype, width]` Load a SIMD vector from the tensor at the specified coordinates. Performs a vectorized load operation from the tensor's memory, retrieving `width` consecutive elements starting at the given coordinates. This method enables efficient SIMD operations on tensor data. Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use `aligned_load` instead when data alignment is guaranteed. * The load operation is optimized based on the tensor's memory layout. Notes: * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. 
* The elements are loaded according to the tensor's stride configuration. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to load into the SIMD vector. Should match the target hardware's vector width for optimal performance. **Args:** * ​coords ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The coordinates to index. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector containing 'width' consecutive elements from the tensor. ### `prefetch` `prefetch(self, m: Int, n: Int)` Prefetch tensor data at the specified 2D coordinates into cache. Issues a software prefetch hint to the processor to load the data at position (m, n) into the cache hierarchy. This can improve performance by reducing memory latency for subsequent accesses to the same location. Performance: * Prefetching is a performance hint and does not guarantee data will be cached. * Most effective when issued sufficiently ahead of the actual data access. * Uses high locality prefetch to the data cache, optimized for data that will be accessed multiple times. * Can reduce memory access latency by 50-90% when used correctly. Notes: * Excessive prefetching can pollute the cache and degrade performance. * Most beneficial for predictable access patterns that would otherwise cause cache misses. * No operation is performed on the prefetched data. **Args:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): The row index (first dimension). * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The column index (second dimension). `prefetch(self, coords: IndexList[size, element_type=element_type])` Prefetch tensor data at the specified coordinates into cache. Issues a software prefetch hint to the processor to load the data at coords into the cache hierarchy. This can improve performance by reducing memory latency for subsequent accesses to the same location. Performance: * Prefetching is a performance hint and does not guarantee data will be cached. 
* Most effective when issued sufficiently ahead of the actual data access. * Uses high locality prefetch to the data cache, optimized for data that will be accessed multiple times. * Can reduce memory access latency by 50-90% when used correctly. Notes: * Excessive prefetching can pollute the cache and degrade performance. * Most beneficial for predictable access patterns that would otherwise cause cache misses. * No operation is performed on the prefetched data. **Args:** * ​coords ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The indices. ### `aligned_load` `aligned_load[width: Int](self, m: Int, n: Int) -> SIMD[dtype, width]` Load a SIMD vector with alignment guarantees from the tensor. Performs an aligned vectorized load operation from the tensor's memory, retrieving `width` consecutive elements starting at position (m, n). The alignment is automatically calculated based on the SIMD width and dtype. Performance: * Uses aligned memory access which is faster than unaligned access on most architectures. * The alignment is automatically calculated based on the SIMD width and dtype. * Can be up to 2x faster than unaligned loads on architectures that require alignment. Notes: * The caller must ensure that the memory at (m, n) is properly aligned. Misaligned access with this method may cause hardware exceptions on some architectures. * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to load into the SIMD vector. Should match the target hardware's vector width for optimal performance. **Args:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): The row index (first dimension). * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The column index (second dimension). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector containing 'width' consecutive elements from the tensor. 
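The `prefetch`/`load`/`aligned_load` family above can be combined as follows. A sketch, assuming a stack-allocated tensor; whether `aligned_load` is safe depends on the allocation's alignment, as the notes above warn:

```mojo
from layout import Layout, LayoutTensor

fn simd_read_sketch():
    alias layout_4x4 = Layout.row_major(4, 4)
    var t = LayoutTensor[DType.float32, layout_4x4, MutAnyOrigin].stack_allocation()
    _ = t.fill(1.0)

    # Advisory hint that row 2 is about to be read; may be a no-op.
    t.prefetch(2, 0)

    # Unaligned load of 4 consecutive elements starting at (2, 0).
    var v = t.load[4](2, 0)

    # Each row of this row-major float32 tensor spans 16 bytes, so row
    # starts are 16-byte aligned when the base allocation is; only then
    # is the aligned variant safe here.
    var w = t.aligned_load[4](2, 0)
```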
### `store` `store[width: Int](self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], m: Int, n: Int, val: SIMD[dtype, width])` Store a SIMD vector to the tensor at the specified 2D coordinates. Performs a vectorized store operation to the tensor's memory, writing 'width' consecutive elements starting at position (m, n). This method enables efficient SIMD operations on tensor data. Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use aligned\_store instead when data alignment is guaranteed. * The store operation is optimized based on the tensor's memory layout. Notes: * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. * The elements are stored according to the tensor's stride configuration. * This operation modifies the tensor's data in-place. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the SIMD vector to store. Should match the target hardware's vector width for optimal performance. **Args:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): The row index (first dimension) where the store operation begins. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The column index (second dimension) where the store operation begins. * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The SIMD vector containing the values to store in the tensor. `store[width: Int](self, coords: IndexList[size, element_type=element_type], val: SIMD[dtype, width])` Store a SIMD vector to the tensor at the specified ND coordinates. Performs a vectorized store operation to the tensor's memory, writing 'width' consecutive elements starting at the given coordinates. This method enables efficient SIMD operations on tensor data. 
Performance: * Uses unaligned memory access which may be slower on some architectures. * For aligned access, use aligned\_store instead when data alignment is guaranteed. * The store operation is optimized based on the tensor's memory layout. Notes: * No bounds checking is performed. Accessing out-of-bounds indices will result in undefined behavior. * The elements are stored according to the tensor's stride configuration. * This operation modifies the tensor's data in-place. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the SIMD vector to store. Should match the target hardware's vector width for optimal performance. **Args:** * ​coords ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The coordinates to index. * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The SIMD vector containing the values to store in the tensor. ### `aligned_store` `aligned_store[width: Int](self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], m: Int, n: Int, val: SIMD[dtype, width])` Store a SIMD vector with alignment guarantees to the tensor. Performs an aligned vectorized store operation to the tensor's memory, writing `width` consecutive elements starting at position (m, n). The alignment is automatically calculated based on the SIMD width and dtype. Performance: * Uses aligned memory access which is faster than unaligned access on most architectures. * The alignment is automatically calculated based on the SIMD width and dtype. * Can be up to 2x faster than unaligned stores on architectures that require alignment. * Particularly important for streaming stores that bypass the cache. Notes: * The caller must ensure that the memory at (m, n) is properly aligned. Misaligned access with this method may cause hardware exceptions on some architectures. * No bounds checking is performed. 
Accessing out-of-bounds indices will result in undefined behavior. * This operation modifies the tensor's data in-place. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the SIMD vector to store. Should match the target hardware's vector width for optimal performance. **Args:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): The row index (first dimension) where the store operation begins. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The column index (second dimension) where the store operation begins. * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The SIMD vector containing the values to store in the tensor. ### `size` `size(self) -> Int` Get the total number of elements that the tensor can contain. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total number of elements that can be stored in the tensor. ### `stack_allocation` `static stack_allocation[*, stack_alignment: Int = alignment]() -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].StackTensorType` Allocates stack memory for a `LayoutTensor` with a fully static layout. Creates a new `LayoutTensor` instance with memory allocated on the stack rather than the heap. This provides deterministic memory management and potentially better performance for tensors with known sizes at compile time. Performance: * Stack allocation is typically faster than heap allocation. * Proper alignment can significantly improve memory access performance, especially for vectorized operations. * No dynamic memory management overhead (no malloc/free calls). Notes: * Only works with tensors that have fully static layouts known at compile time. * Stack memory is limited, so this should only be used for reasonably sized tensors. * The allocated memory is automatically freed when the function returns. 
**Constraints:** * The layout must be fully static (all dimensions known at compile time). * The alignment must be a multiple of the tensor's minimum required alignment. **Parameters:** * ​stack\_alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Memory alignment value for the allocation in bytes. Must be a multiple of the tensor's minimum required alignment. Default is the tensor's natural alignment based on its data type and layout. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new `LayoutTensor` instance with memory allocated on the stack. ### `null` `static null() -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].StackTensorType` Returns a null `LayoutTensor` object. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A null `LayoutTensor` object. ### `to_device_buffer` `to_device_buffer(self, ctx: DeviceContext) -> DeviceBuffer[dtype]` Convert the tensor to a `DeviceBuffer`. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The device context to use. **Returns:** [`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer): A `DeviceBuffer` containing the tensor's data. ### `is_static_shape` `static is_static_shape[idx: Int]() -> Bool` Returns whether the specified dimension is statically known. Performance: * This is a compile-time operation with no runtime cost when used with static dimensions. Notes: * This is a static method that operates on the tensor's type information, not on a specific tensor instance. **Parameters:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape \[10, UNKNOWN\_VALUE, 30]: * `is_static_shape[0]()` returns True (first dimension). * `is_static_shape[1]()` returns False (second dimension). * `is_static_shape[2]()` returns True (third dimension). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the dimension is statically known, False otherwise. ### `shape` `static shape[idx: Int]() -> Int` Returns the size of the tensor along the specified dimension. Provides static access to the tensor's shape information. This method returns the size of a specific dimension without requiring an instance of the tensor, as the shape is part of the tensor's static type information. Performance: * This is a compile-time operation with no runtime cost when used with static dimensions. Notes: * This is a static method that operates on the tensor's type information, not on a specific tensor instance. **Parameters:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape \[10, 20, 30]: * `shape[0]()` returns 10 (first dimension). * `shape[1]()` returns 20 (second dimension). * `shape[2]()` returns 30 (third dimension). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the tensor along the specified dimension as an integer. ### `get_shape` `get_shape(self) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Get the flattened shape of a LayoutTensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The flattened shape of a LayoutTensor. ### `get_stride` `get_stride(self) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Get the flattened stride of a LayoutTensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The flattened stride of a LayoutTensor. 
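The static and instance-level shape queries above differ in when they are answered. A sketch (construction is an assumption):

```mojo
from layout import Layout, LayoutTensor

fn shape_queries_sketch():
    alias layout_2x3 = Layout.row_major(2, 3)
    alias Tensor2x3 = LayoutTensor[DType.float32, layout_2x3, MutAnyOrigin]

    # Static queries are answered from the type alone, at compile time.
    alias rows = Tensor2x3.shape[0]()             # 2
    alias known = Tensor2x3.is_static_shape[0]()  # True: fully static layout

    # Instance queries return runtime IndexList values.
    var t = Tensor2x3.stack_allocation()
    var shape = t.get_shape()    # (2, 3)
    var stride = t.get_stride()  # (3, 1) for this row-major layout
```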
### `stride` `static stride[idx: Int]() -> Int` Returns the memory stride of the tensor along the specified dimension. Provides static access to the tensor's stride information. The stride represents the number of elements to skip in memory to move one position along a particular dimension. This method returns the stride without requiring an instance of the tensor, as the stride is part of the tensor's static type information. Performance: * This is a compile-time operation with no runtime cost when used with static dimensions. * Understanding stride patterns is crucial for optimizing memory access patterns in performance-critical code. Notes: * Strides depend on the memory layout (row-major, column-major, or custom). * For non-contiguous tensors (e.g., tensor slices), strides may not follow a simple pattern. **Parameters:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension index to query (0-based). For example, in a 2D tensor with shape \[10, 20] and row-major layout: * `stride[0]()` might return 20 (moving one row requires skipping 20 elements). * `stride[1]()` might return 1 (moving one column requires skipping 1 element). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The memory stride of the tensor along the specified dimension as an integer. `stride(self, idx: Int) -> Int` Returns the runtime stride of the tensor along the specified axis. Unlike the static `stride` method, this instance method takes a runtime dimension index. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension index to query (0-based). For example, in a row-major 3D tensor with shape `[10, 20, 30]`: * `stride(0)` returns 600 (first dimension). * `stride(1)` returns 30 (second dimension). * `stride(2)` returns 1 (third dimension). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The stride of the tensor along the specified axis as an integer. 
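The static and runtime forms of `stride` (and the related `dim` accessor documented next) can be contrasted in a short sketch; the construction pattern is an assumption:

```mojo
from layout import Layout, LayoutTensor

fn stride_dim_sketch():
    alias layout_10x20 = Layout.row_major(10, 20)
    alias T = LayoutTensor[DType.float32, layout_10x20, MutAnyOrigin]

    # Static form: resolved at compile time from the layout.
    alias row_stride = T.stride[0]()  # 20: one row step skips 20 elements
    alias col_stride = T.stride[1]()  # 1

    # Runtime form: same values, but the axis index is a runtime argument.
    var t = T.stack_allocation()
    var s0 = t.stride(0)  # 20
    var d0 = t.dim(0)     # 10
```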
### `dim` `dim(self, idx: Int) -> Int` Returns the runtime dimension size of the tensor along the specified axis. Unlike the parametric `dim[idx]` overload, this instance method takes a runtime dimension index. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape `[10, 20, 30]`: * `dim(0)` returns 10 (first dimension). * `dim(1)` returns 20 (second dimension). * `dim(2)` returns 30 (third dimension). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The dimension of the tensor along the specified axis as an integer. `dim[idx: Int](self) -> Int` Returns the dimension size of the tensor along the specified axis. Unlike the static `shape` method, this instance method provides access to the tensor's actual dimension sizes. If the dimension is unknown, the runtime layout is used to get the dimension size. Performance: * For static dimensions known at compile time, prefer the static `shape` method when possible for better performance. Notes: * This method works with both static and dynamic dimensions. * For tensors with masked or partial views, this returns the actual size of the view, not the original tensor. **Constraints:** * Only works with tensors that have depth-1 layouts (no nested shapes). **Parameters:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimension index to query (0-based). For example, in a 3D tensor with shape `[10, 20, 30]`: * `dim[0]()` returns 10 (first dimension). * `dim[1]()` returns 20 (second dimension). * `dim[2]()` returns 30 (third dimension). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the tensor along the specified dimension as an integer. ### `coalesce` `coalesce(self) -> LayoutTensor[dtype, coalesce(layout, False), origin, address_space=address_space, element_layout=element_layout]` Creates a tensor with a coalesced memory layout from this tensor. 
Coalescing a tensor's layout means reorganizing its memory representation to be as contiguous as possible, which can improve memory access patterns and performance. This operation does not move or copy data; it only changes how the same memory is interpreted. Performance: * Coalesced layouts typically provide better cache utilization and memory access patterns. * This operation is zero-cost at runtime as it only changes the layout information, not the actual data. * Particularly beneficial before operations that perform sequential memory access or vectorized operations. Notes: * The coalesced tensor shares the same memory as the original tensor, so modifications to one will affect the other. * The shape of the tensor remains the same, only the stride information is optimized. * For already optimally coalesced tensors, this operation has no effect. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A tensor with the same data but with a coalesced memory layout. The returned tensor has type `LayoutTensor` with the same dtype but with a coalesced layout. ### `tile` `tile[*tile_sizes: Int](self, *tile_coords: Int) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes](), alignment=alignment]` Extract a tile (sub-tensor) from this tensor with specified dimensions and position. Tiling is a fundamental operation for high-performance tensor computations that divides a tensor into smaller blocks for better cache locality and parallelism. This method extracts a specific tile at the given coordinates without copying data. 
Example: For a 4x4 tensor with values: ``` [1 2 3 4] [2 3 4 5] [5 4 3 2] [1 1 1 1] ``` `tile[2, 2](1, 0)` will extract the tile: ``` [5 4] [1 1] ``` Performance: * Creates a view without copying data, making it very efficient. * Optimized for both static and dynamic layouts with different code paths. * Properly handles edge cases where tiles may be partially outside the tensor. * Maintains stride information for efficient memory access within the tile. Notes: * The resulting tile is a view into the original tensor, so modifications to the tile will affect the original tensor. * For tiles at the edges of the tensor, the actual dimensions may be smaller than the requested tile\_sizes if masking is enabled. * The implementation automatically selects between static and dynamic tiling based on the tensor's layout properties. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. For example, in a 2D tensor, `tile[32, 32]` creates 32x32 tiles. **Args:** * ​\*tile\_coords ([`Int`](/mojo/stdlib/builtin/int/Int)): The coordinates of the specific tile to extract. For example, `tile[32, 32](1, 2)` extracts the tile at position (1, 2) in the grid of 32x32 tiles. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view into the original tensor representing the specified tile. ### `simd_tile` `simd_tile[tile_size: Int](self, tile_idx: Int) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_size, simd_width_of[dtype]()]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_size, simd_width_of[dtype]()](), alignment=alignment]` Return a SIMD\[dtype] sized tile of size `tile_size` at `tile_idx`. 
**Parameters:** * ​tile\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the tile along the tiled axis used for vectorization. **Args:** * ​tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the tile to extract along the tiled axis. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A SIMD\[dtype] tile of size `tile_size` at `tile_idx` ### `tile_with_offset` `tile_with_offset[*tile_sizes: Int](self, *tile_coords: Int) -> Tuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes](), alignment=alignment], LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].CornerCoordsType, Scalar[linear_idx_type]]` Similar to `tile`, but also returns the corner coordinates of the tile as well as the offset. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. **Args:** * ​\*tile\_coords ([`Int`](/mojo/stdlib/builtin/int/Int)): The coordinates of the specific tile to extract. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): A tuple containing: * The extracted tile as a `LayoutTensor`. * The corner coordinates of the tile. * The offset of the tile. 
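As a hedged sketch of how `tile` might be used (the `InlineArray`-backed construction pattern mirrors the `copy_from` example later in this document; `Layout.row_major` and the constructor details are assumptions that may vary between Mojo versions):

```mojo
from layout import Layout, LayoutTensor

fn tile_demo():
    # Back a 4x4 float32 tensor with inline storage.
    var storage = InlineArray[Float32, 4 * 4](uninitialized=True)
    var t = LayoutTensor[DType.float32, Layout.row_major(4, 4)](storage).fill(0.0)

    # Extract the 2x2 tile at grid position (1, 0): rows 2-3, columns 0-1.
    var block = t.tile[2, 2](1, 0)

    # The tile is a view: this write is also visible through `t`.
    block[0, 0] = 5.0
```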
### `tiled_iterator` `tiled_iterator[*tile_sizes: Int, *, axis: Int = 0](self, *tile_coords: Int) -> LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, tile_sizes]()[0], origin, address_space=address_space, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, tile_sizes]()]` Create an iterator that traverses tiles along a specified axis. This method creates an iterator that allows efficient traversal of tiles within a tensor. The iterator starts at the specified tile coordinates and can move along the specified axis, providing access to consecutive tiles. Performance: * Provides efficient sequential access to tiles with good cache locality. * Optimized for both static and dynamic layouts with different code paths. * Maintains stride information for efficient memory access within each tile. * Properly handles edge cases where tiles may be partially outside the tensor. Notes: * The iterator provides views into the original tensor, so modifications through the iterator will affect the original tensor. * For tiles at the edges of the tensor, the actual dimensions may be smaller than the requested tile\_sizes if masking is enabled. * The iterator is not circular by default, meaning it will not wrap around when reaching the end of the tensor along the iteration axis. * The implementation automatically selects between static and dynamic tiling based on the tensor's layout properties. Example: ```mojo var iter = tensor.tiled_iterator[16, 16, axis=0](0, 0) for i in range(num_tiles_along_axis): var tile = iter.get() # Process tile iter.next() ``` **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensions of each tile along each axis of the tensor. For example, in a 2D tensor, `tiled_iterator[32, 32]` creates an iterator over 32x32 tiles. 
* ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which the iterator will traverse. Default is 0 (first dimension). For example, with axis=0, the iterator will move vertically through tiles. **Args:** * ​\*tile\_coords ([`Int`](/mojo/stdlib/builtin/int/Int)): The starting coordinates of the tile where iteration begins. **Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A `LayoutTensorIter` that can be used to traverse tiles along the specified axis. ### `split` `split[count: Int, axis: Int = 0](self) -> StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, (layout.shape[axis].value() // count), axis]()[0], origin, address_space=address_space, element_layout=element_layout, alignment=alignment], count]` Split the `LayoutTensor` along an axis and return a `StaticTuple` of `LayoutTensor` partitions. **Parameters:** * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of partitions to split the tensor into. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to split the tensor. **Returns:** `StaticTuple`: A `StaticTuple` containing `count` `LayoutTensors`, each representing an equal-sized partition of the original tensor along the specified axis. Each partition has the same data type and memory characteristics as the original tensor, but with a reduced size along the split axis. `split[axis: Int = 0, split_alignment: Int = 1](self, count: Int, idx: Int) -> LayoutTensor[dtype, layout.make_shape_unknown[axis](), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Retrieve a specific partition of the tensor after splitting along a specified axis. This method divides the tensor into `count` partitions along the specified axis and returns the partition at index `idx`.
The partitioning is done with alignment considerations to optimize memory access patterns. Unlike the overloaded split method that returns all partitions, this method returns only a single partition, making it more memory-efficient for cases where only one partition is needed at a time. Notes: * The shape along the split axis becomes unknown at compile time. * Only works with dimensions that have statically known sizes. * The last partition may be smaller than others if the dimension size is not evenly divisible by `count`. * Partition sizes are aligned up to the specified alignment value, which can improve performance for vectorized operations. Performance: * Uses aligned partitioning to improve memory access patterns. * Avoids creating all partitions in memory, reducing memory usage. * Maintains the original tensor's stride information for efficient element access within the partition. **Constraints:** * The dimension being split must have a statically known size. * Cannot split dimensions with unknown or dynamic sizes. **Parameters:** * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to split the tensor. Defaults to 0 (first dimension). * ​split\_alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Memory alignment value for the partition size. Defaults to 1. **Args:** * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of partitions to divide the tensor into. * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the partition to return (0-based). **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A `LayoutTensor` representing the requested partition. 
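A minimal sketch contrasting the two `split` overloads (tensor construction follows the `copy_from` example later in this document; exact constructor details are assumptions):

```mojo
from layout import Layout, LayoutTensor

fn split_demo():
    var storage = InlineArray[Float32, 8 * 4](uninitialized=True)
    var t = LayoutTensor[DType.float32, Layout.row_major(8, 4)](storage).fill(1.0)

    # Static overload: all partitions at once, as a StaticTuple.
    var halves = t.split[2]()   # two 4x4 views along axis 0
    var top = halves[0]
    var bottom = halves[1]

    # Runtime overload: materialize only the partition that is needed.
    # Splits into 4 partitions along axis 0 and returns partition index 1.
    var second = t.split(4, 1)
```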
### `distribute` `distribute[threads_layout: Layout, axis: OptionalReg[Int] = None, swizzle: OptionalReg[Swizzle] = None, submode_axis: OptionalReg[Int] = None](self, thread_id: UInt) -> LayoutTensor[dtype, _compute_distribute_layout[layout, threads_layout, axis]()[1], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _distribute_is_masked[layout, threads_layout, axis]() if is_nvidia_gpu() else False]` Distribute tensor workload across multiple threads in a structured pattern. This method partitions a tensor across multiple threads for parallel processing, assigning each thread a specific portion of the tensor. The distribution pattern is determined by the threads\_layout parameter, which defines the logical arrangement of threads. Example: For a 4x4 row-major tensor distributed across 4 threads in a 2x2 row-major grid: * Thread 0 will receive a LayoutTensor with a view into (0,0), (0,2), (2,0), (2,2) of the original tensor. * Thread 1 will receive a LayoutTensor with a view into (0,1), (0,3), (2,1), (2,3) of the original tensor. * Thread 2 will receive a LayoutTensor with a view into (1,0), (1,2), (3,0), (3,2) of the original tensor. * Thread 3 will receive a LayoutTensor with a view into (1,1), (1,3), (3,1), (3,3) of the original tensor. If axis=0 is specified with the same setup: * Thread (0, 0) and Thread (0, 1) would get the same data (top half) * Thread (1, 0) and Thread (1, 1) would get the same data (bottom half) Performance: * Creates a view without copying data, making it very efficient for parallel processing. * The swizzle parameter can significantly improve cache locality and memory access patterns. * Optimized for both static and dynamic layouts with different code paths. Notes: * The resulting tensor is a view into the original tensor, so modifications will affect the original tensor. 
* For optimal performance, the `threads_layout` should match the hardware's thread organization (e.g., warp/wavefront size and shape). * When using swizzling, carefully consider the memory access patterns to avoid cache thrashing or bank conflicts. * This function is particularly useful for GPU programming where threads are organized in structured grids. **Constraints:** * For dynamic layouts, the shape must be known at runtime and the threads\_layout must be fully static. **Parameters:** * ​threads\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Defines the logical arrangement of threads (e.g., 2x2 grid of 4 threads). This layout determines how the tensor is partitioned. * ​axis ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional. If specified, restricts distribution to only this axis. For example, with axis=0 in a 2D thread layout, threads that differ only in their second coordinate will receive the same data. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional. A function that remaps the distribution pattern to improve memory access patterns or cache locality. * ​submode\_axis ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional. Specifies an axis for specialized distribution modes. **Args:** * ​thread\_id ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The ID of the current thread (0-based). **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view into the original tensor representing the portion assigned to this thread. 
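The 2x2 example above can be sketched inside a GPU kernel as follows (an illustration only; the function signature, `MutAnyOrigin`, and use of `thread_idx.x` as the thread ID are assumptions, and the kernel launch boilerplate is omitted):

```mojo
from gpu import thread_idx
from layout import Layout, LayoutTensor

fn distribute_demo(t: LayoutTensor[DType.float32, Layout.row_major(4, 4), MutAnyOrigin]):
    # Partition the 4x4 tensor across a 2x2 row-major grid of threads.
    alias threads_2x2 = Layout.row_major(2, 2)
    var frag = t.distribute[threads_2x2](thread_idx.x)
    # Thread 0 now views elements (0,0), (0,2), (2,0), (2,2); writes
    # through `frag` modify the original tensor.
    _ = frag
```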
### `distribute_with_offset` `distribute_with_offset[threads_layout: Layout, axis: OptionalReg[Int] = None, swizzle: OptionalReg[Swizzle] = None, submode_axis: OptionalReg[Int] = None](self, thread_id: UInt) -> Tuple[LayoutTensor[dtype, _compute_distribute_layout[layout, threads_layout, axis]()[1], origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _distribute_is_masked[layout, threads_layout, axis]() if is_nvidia_gpu() else False], IndexList[threads_layout.rank(), element_type=layout_int_type], Scalar[linear_idx_type]]` Similar to `distribute`, but also returns the corner coordinates of the tile as well as the offset. **Parameters:** * ​threads\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the threads. * ​axis ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The axis to distribute along. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An optional swizzle function. * ​submode\_axis ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An optional submode axis. **Args:** * ​thread\_id ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The ID of the current thread (0-based). **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): A tuple containing: * The distributed tensor. * The corner coordinates of the tile. * The offset of the tile. 
### `vectorize` `vectorize[*vector_shape: Int](self) -> LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[1], True), origin, address_space=address_space, element_layout=LayoutTensor._divide_tiles[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment, vector_shape]()[0], layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Reshape a tensor into a vectorized form for efficient SIMD operations. This method transforms the tensor's logical layout to enable efficient vectorized processing, treating blocks of elements as vector units. The transformation is particularly useful for SIMD (Single Instruction Multiple Data) operations and hardware acceleration. Example: For a 16x16 tensor, `vectorize[4, 4]` will produce a 4x4 tensor where each element represents a 4x4 block from the original tensor. Performance: * Creates a view without copying data, making it very efficient. * Enables hardware-accelerated vector operations on blocks of data. * Improves cache locality by grouping related elements together. * Particularly beneficial for operations that can leverage SIMD instructions. Notes: * The tensor dimensions must be divisible by the corresponding vector dimensions. * For dimensions with unknown size, the corresponding vector dimension must be 1. * The resulting tensor has the same data but a different logical organization. * Modifications to the vectorized tensor affect the original tensor. * This transformation is particularly useful for GPU and vector processor optimizations. **Constraints:** * Each tensor dimension must be divisible by the corresponding vector dimension. * Vector dimensions must be smaller than or equal to the corresponding tensor dimensions. * For dimensions with unknown size, the vector dimension must be 1. 
**Parameters:** * ​\*vector\_shape ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensions of each vector unit along each axis of the tensor. For example, in a 2D tensor, `vectorize[4, 4]` treats 4x4 blocks as vector units. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with a vectorized layout, where each element in the resulting tensor represents a vector of elements from the original tensor. `vectorize(self) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].SIMDVectorizedType` Return a SIMD\[dtype] vectorized view of this tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A `Self.VectorizedType[1, simd_width_of[Self.dtype]()]` view whose width equals the SIMD width for the tensor's dtype. ### `slice` `slice[d0_slice: Slice, d1_slice: Slice](self) -> LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Extract a slice from a rank-2 tensor using slice objects. This method creates a view into a subset of the tensor defined by the slice specifications for each dimension. The slice is a contiguous region of the tensor with no gaps (step size must be 1). Example: For a 4x4 tensor `t` with values: ``` [1 2 3 4] [5 6 7 8] [9 10 11 12] [13 14 15 16] ``` ```mojo t.slice[Slice(1, 3), Slice(0, 2)]() ``` will extract: ``` [5 6] [9 10] ``` Performance: * Creates a view without copying data, making it very efficient. * Maintains the original tensor's stride information for efficient memory access. * Zero-cost abstraction at runtime when used with compile-time constant slices.
Notes: * The slice is a view into the original tensor, so modifications to the slice will affect the original tensor. * Only supports rank-2 tensors. For higher-rank tensors, use the overloaded version with slice indices. * The step size must be 1 (no gaps allowed in the slice). * Slice bounds are not checked at runtime; accessing out-of-bounds indices will result in undefined behavior. **Constraints:** * Only works with rank-2 tensors. **Parameters:** * ​d0\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the first dimension (rows). Defines the start and end indices for the slice along this dimension. * ​d1\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the second dimension (columns). Defines the start and end indices for the slice along this dimension. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view into the original tensor representing the specified slice. `slice[d0_slice: Slice, d1_slice: Slice, slice_indices: IndexList[2], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)](self, offsets: IndexList[__offset_dims]) -> LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, d1_slice, slice_indices.__getitem__[2, DType.int64, Int](0), slice_indices.__getitem__[2, DType.int64, Int](1)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Extract a 2D slice from a higher-rank tensor at specific indices. This method creates a view into a 2D subset of a higher-rank tensor: Selecting two dimensions to slice using the slice\_indices parameter. 
It applies slice specifications to those dimensions and uses fixed offsets for all other dimensions. Example: Given a 3x4x5 tensor, `t`, the following example extracts a 2x2 slice from dimensions 0 and 2, with dimension 1 fixed at index 1. ```mojo var v = t.slice[Slice(1, 3), Slice(0, 2), IndexList[2](0, 2)](1) ``` Performance: * Creates a view without copying data, making it very efficient. * Maintains the original tensor's stride information for efficient memory access. * Zero-cost abstraction at runtime when used with compile-time constant slices. Notes: * The slice is a view into the original tensor, so modifications to the slice will affect the original tensor. * The slice indices must be ordered (e.g., \[0, 2] is valid, \[2, 0] is not). * The step size must be 1 (no gaps allowed in the slice). * Slice bounds are not checked at runtime; accessing out-of-bounds indices will result in undefined behavior. **Constraints:** * Slice step size must be 1 (no gaps). * Slice indices must be ordered (ascending). * Tensor rank must be at least 2. **Parameters:** * ​d0\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the first selected dimension. * ​d1\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the second selected dimension. * ​slice\_indices ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Indices of the two dimensions to slice (must be ordered). * ​\_\_offset\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Internal parameter representing number of fixed dimensions. **Args:** * ​offsets ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Fixed index values for all dimensions not being sliced. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A 2D view into the original tensor representing the specified slice.
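The rank-2 overload can be sketched end to end as follows (a hedged example; tensor construction follows the `copy_from` example later in this document, and constructor details are assumptions):

```mojo
from layout import Layout, LayoutTensor

fn slice_demo():
    var storage = InlineArray[Float32, 4 * 4](uninitialized=True)
    var t = LayoutTensor[DType.float32, Layout.row_major(4, 4)](storage).fill(0.0)

    # Rank-2 overload: rows 1-2, columns 0-1 (step size must be 1).
    var s = t.slice[Slice(1, 3), Slice(0, 2)]()

    # The slice is a view: this write also updates t[1, 0].
    s[0, 0] = 9.0
```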
### `slice_1d` `slice_1d[d0_slice: Slice, slice_indices: IndexList[1], __offset_dims: Int = (LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 1)](self, offsets: IndexList[__offset_dims]) -> LayoutTensor[dtype, LayoutTensor._compute_slice_layout[mut, dtype, layout, origin, address_space, element_layout, layout_int_type, linear_idx_type, masked, alignment](d0_slice, slice_indices.__getitem__[1, DType.int64, Int](0)), origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Extract a 1D slice from a higher-rank tensor at a specific index. This method creates a view into a 1D subset of a higher-rank tensor by: 1. Selecting one dimension to slice using the slice\_indices parameter 2. Applying a slice specification to that dimension 3. Using fixed offsets for all other dimensions Example: For a 3x4x5 tensor, `t`, the following example extracts a 1D slice from dimension 0, with dimensions 1 and 2 fixed at indices 1 and 2: ```mojo t.slice_1d[Slice(1, 3), IndexList[1](0)](1, 2) ``` Performance: * Creates a view without copying data, making it very efficient. * Maintains the original tensor's stride information for efficient memory access. * Zero-cost abstraction at runtime when used with compile-time constant slices. Notes: * The slice is a view into the original tensor, so modifications to the slice will affect the original tensor. * The step size must be 1 (no gaps allowed in the slice). * Slice bounds are not checked at runtime; accessing out-of-bounds indices will result in undefined behavior. * This function exists as a workaround for compiler limitations with overloading. **Constraints:** * Slice step size must be 1 (no gaps). * Tensor rank must be at least 1. 
**Parameters:** * ​d0\_slice ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): Slice specification for the selected dimension. * ​slice\_indices ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Index of the dimension to slice. * ​\_\_offset\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Internal parameter representing number of fixed dimensions. **Args:** * ​offsets ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Fixed index values for all dimensions not being sliced. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A 1D view into the original tensor representing the specified slice. ### `transpose` `transpose(self) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].TransposeType` Create a transposed view of a tensor. This method creates a view of the tensor with its dimensions swapped, effectively converting rows to columns and columns to rows. The transposition is performed without copying data, by adjusting the tensor's layout information. Example: For a 2x3 tensor with values: ``` [1 2 3] [4 5 6] ``` `transpose()` will produce a 3x2 tensor: ``` [1 4] [2 5] [3 6] ``` Performance: * Creates a view without copying data, making it very efficient. * The operation is zero-cost at runtime as it only changes the layout information. * Memory access patterns may be less efficient in the transposed view due to non-contiguous memory access, especially for row-major storage. Notes: * The transposed tensor shares the same memory as the original tensor, so modifications to one will affect the other. * For optimal performance when repeatedly accessing the transposed data, consider creating a physical copy with the transposed layout. * Transpose only works with statically known shapes. 
**Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with dimensions transposed (rows become columns and vice versa). ### `reshape` `reshape[dst_layout: Layout](self) -> LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Create a view of the tensor with a different shape. This method creates a view of the tensor with a new shape, without changing the underlying data. The total number of elements must remain the same. Example: Given a 2x6 row-major tensor, `reshape[Layout.col_major(3, 4)]()` produces a 3x4 tensor with the same elements in column-major order. Performance: * Creates a view without copying data, making it very efficient. * The operation is zero-cost at runtime as it only changes the layout information. * Memory access patterns may change, potentially affecting performance depending on the original and target layouts. Notes: * The reshaped tensor shares the same memory as the original tensor, so modifications to one will affect the other. * The total number of elements must remain the same after reshaping. * The reshape operation assumes a row-major (C-style) memory layout. * For tensors with complex strides or non-contiguous memory, reshaping may not produce the expected results. * Masked tensors cannot be reshaped. **Constraints:** * Cannot reshape masked tensors. * The total number of elements must be the same in both layouts. **Parameters:** * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped tensor. Must have the same total number of elements as the original tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with the new shape specified by dst\_layout. 
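A minimal sketch of `transpose` and `reshape` as zero-copy views (assuming the construction pattern from the `copy_from` example later in this document):

```mojo
from layout import Layout, LayoutTensor

fn reshape_demo():
    var storage = InlineArray[Float32, 2 * 6](uninitialized=True)
    var t = LayoutTensor[DType.float32, Layout.row_major(2, 6)](storage).fill(0.0)

    # Same 12 elements viewed as a 3x4 tensor; no data is moved.
    var r = t.reshape[Layout.row_major(3, 4)]()

    # 6x2 transposed view sharing the same memory as `t`.
    var tt = t.transpose()

    # Writes through any view are visible through the others.
    r[0, 0] = 1.0   # also t[0, 0] and tt[0, 0]
```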
`reshape[dst_layout: Layout](self, runtime_layout: RuntimeLayout[dst_layout]) -> LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Create a view of the tensor with a different shape. This method creates a view of the tensor with a new shape, without changing the underlying data. The total number of elements must remain the same. Example: Given a 2x6 row-major tensor, `reshape[Layout.col_major(3, 4)]()` produces a 3x4 tensor with the same elements in column-major order. Performance: * Creates a view without copying data, making it very efficient. * The operation is zero-cost at runtime as it only changes the layout information. * Memory access patterns may change, potentially affecting performance depending on the original and target layouts. Notes: * The reshaped tensor shares the same memory as the original tensor, so modifications to one will affect the other. * The total number of elements must remain the same after reshaping. * The reshape operation assumes a row-major (C-style) memory layout. * For tensors with complex strides or non-contiguous memory, reshaping may not produce the expected results. * Masked tensors cannot be reshaped. **Constraints:** * Cannot reshape masked tensors. * The total number of elements must be the same in both layouts. **Parameters:** * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped tensor. Must have the same total number of elements as the original tensor. **Args:** * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The target RuntimeLayout for the reshaped tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with the new shape specified by dst\_layout. 
### `flatten` `flatten(self) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].FlattenedType` Convert a LayoutTensor to a flattened dynamic layout. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of this tensor with a flattened dynamic layout. ### `composition` `composition[rhs_layout: Layout, dst_layout: Layout = composition(layout, rhs_layout)](self) -> LayoutTensor[dtype, dst_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Create a view of the tensor with a composed layout. This method creates a view of the tensor with a new layout that is the composition of the original layout with another layout. Layout composition allows for complex transformations of the tensor's logical structure without copying data. Example: For a 4x4 tensor with a standard row-major layout, composing with a layout that represents a 2x2 tiling would result in a tensor that logically views the data as 2x2 blocks. Performance: * Creates a view without copying data, making it very efficient. * The operation is zero-cost at runtime as it only changes the layout information. * Can be used to optimize memory access patterns for specific algorithms. Notes: * The composed tensor shares the same memory as the original tensor, so modifications to one will affect the other. * Layout composition is a powerful tool for expressing complex data transformations like tiling, transposition, and reshaping in a unified framework. * Understanding the mathematical properties of layout composition is important for correctly using this function. **Constraints:** * The layouts must be compatible for composition. * The total number of elements must remain the same after composition.
**Parameters:** * ​rhs\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout to compose with the tensor's current layout. * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The resulting layout after composition. Defaults to the composition of the tensor's layout with rhs\_layout. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A view of the tensor with the composed layout. ### `distance` `distance(self, addr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin]) -> Scalar[linear_idx_type]` Calculate the element-wise distance between this tensor's pointer and another pointer. This method computes the number of elements (not bytes) between the tensor's pointer and the provided address. This is useful for determining offsets within a larger memory allocation or for pointer arithmetic operations. Example: If `tensor.ptr` points to an element at index 100 in a buffer, and `addr` points to element at index 50, then `distance(addr)` returns 50. Performance: * This is a lightweight operation that only involves pointer arithmetic. * The operation is optimized based on the address space, using smaller integer types for shared memory to improve efficiency. Notes: * The distance is calculated in elements, not bytes. * The result can be positive or negative depending on the relative positions of the pointers. * This function is particularly useful for GPU programming where understanding memory offsets is critical for performance. * Care should be taken when using this with pointers from different allocations, as the result would be meaningless. **Args:** * ​addr (`LegacyUnsafePointer`): The target pointer to calculate the distance to. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The number of elements between this tensor's pointer and the provided address. The result is of type `_uint_dtype`. 
`distance[_layout: Layout, _uint_dtype: DType = _get_unsigned_type(_layout, address_space)](self, src: LayoutTensor[dtype, _layout, origin, address_space=address_space]) -> Scalar[_uint_dtype]` Calculate the element-wise distance between this tensor and another tensor. This method computes the number of elements (not bytes) between this tensor's pointer and another tensor's pointer. This is useful for determining the relative positions of tensors within a larger memory allocation. Example: If tensor1 points to element at index 100 in a buffer, and tensor2 points to element at index 50, then `tensor1.distance(tensor2)` would return 50. Performance: * This is a lightweight operation that only involves pointer arithmetic. * The operation is optimized based on the address space and layout, using appropriate integer types for efficiency. Notes: * The distance is calculated in elements, not bytes. * The result can be positive or negative depending on the relative positions of the tensors. * This function is particularly useful for GPU programming where understanding memory offsets is critical for performance. * Both tensors must be in the same address space for the result to be meaningful. * This overload is more type-safe than the pointer-based version as it ensures the tensors have compatible data types and address spaces. **Parameters:** * ​\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the source tensor. * ​\_uint\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The unsigned integer type to use for the result. Automatically determined based on the layout and address space. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor to calculate the distance to. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The number of elements between this tensor's pointer and the source tensor's pointer. The result is of type \_uint\_dtype. 
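A hedged sketch of `distance` using two tile views carved from one tensor (the function signature and `MutAnyOrigin` are illustrative assumptions):

```mojo
from layout import Layout, LayoutTensor

fn distance_demo(buf: LayoutTensor[DType.float32, Layout.row_major(2, 16), MutAnyOrigin]):
    # Two row views of the same underlying buffer.
    var a = buf.tile[1, 16](0, 0)   # base pointer at element 0
    var b = buf.tile[1, 16](1, 0)   # base pointer at element 16

    # Offset in elements (not bytes) from `a`'s pointer to `b`'s pointer.
    var off = b.distance(a.ptr)     # 16
    _ = off
```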
### `copy_from` `copy_from(self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], other: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Copy data from another tensor to this tensor. This method performs an element-by-element copy from the source tensor to this tensor, respecting the layouts of both tensors. The copy operation handles different memory layouts correctly, ensuring that elements are copied to their proper positions regardless of how the data is arranged in memory. Example: ```mojo from layout import LayoutTensor, Layout var src_storage = InlineArray[Float32, 2 * 3](uninitialized=True) var dst_storage = InlineArray[Float32, 3 * 2](uninitialized=True) var src = LayoutTensor[ DType.float32, Layout([2, 3]), ](src_storage).fill(1.0) var dst = LayoutTensor[ DType.float32, Layout([3, 2]), ](dst_storage) dst.copy_from(src) # Copies all elements from src to dst ``` Performance: * Performs element-by-element copying, which may be less efficient than vectorized or bulk memory operations. * The copy respects the memory layout of both tensors, which may involve non-contiguous memory access patterns. * For optimal performance with large tensors, consider using specialized copy functions that can leverage hardware acceleration. Notes: * Both tensors must have statically known shapes. * The total number of elements must be the same in both tensors. * The element sizes must match between the tensors. 
* This function handles different memory layouts correctly, making it suitable for copying between tensors with different shapes or strides. * The copy is performed element by element, not as a bulk memory copy. **Args:** * ​other ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor to copy data from. Must have the same total number of elements as this tensor. ### `copy_from_async` `copy_from_async[is_masked: Bool = False, swizzle: OptionalReg[Swizzle] = None, fill: Fill = Fill.NONE, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](self, src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_idx_bound: Scalar[linear_idx_type] = 0, base_offset: Scalar[linear_idx_type] = 0)` Asynchronously copy data from another tensor to this tensor using GPU hardware. This method performs an asynchronous copy from the source tensor to this tensor using GPU hardware acceleration. It's specifically designed for copying data from global memory to shared memory in GPU kernels, leveraging hardware-specific asynchronous copy mechanisms for improved performance. For optimal performance, you need to arrange the copy correctly. Use the [`distribute()`](/mojo/kernels/layout/layout_tensor/LayoutTensor/#distribute) method to create thread-local fragments of the source and destination tensors, assigning each thread one or more elements to copy. Optionally, use the [`vectorize()`](/mojo/kernels/layout/layout_tensor/LayoutTensor/#vectorize) method to get vectorized views of both tensors before calling `distribute()`. This allows each thread to copy multiple elements of the tensor. 
For example: ```mojo var fragment = tensor.vectorize[1, simd_width]().distribute[ thread_layout ](thread_id) ``` The copy operation is asynchronous, so you must call [`async_copy_wait_all()`](/mojo/stdlib/gpu/memory/async_copy_wait_all/) or [`async_copy_wait_group()`](/mojo/stdlib/gpu/memory/async_copy_wait_group/) to ensure the copy has completed before using the data. Example: ```mojo from layout import LayoutTensor, Layout from gpu import thread_idx, block_idx, block_dim from gpu.memory import async_copy_wait_all comptime dtype = DType.float32 comptime in_size = 128 comptime block_size = 16 num_blocks = in_size // block_size comptime input_layout = Layout.row_major(in_size, in_size) fn kernel(tensor: LayoutTensor[dtype, input_layout, MutAnyOrigin]): # extract a tile from the input tensor. var global_tile = tensor.tile[block_size, block_size](block_idx.x, block_idx.y) # allocate a shared memory tile comptime tile_layout = Layout.row_major(block_size, block_size) var shared_tile = LayoutTensor[ dtype, tile_layout, MutAnyOrigin, address_space = AddressSpace.SHARED, ].stack_allocation() # Create per-thread tile fragments for copying var tid = thread_idx.y + thread_idx.x * block_dim.x comptime thread_layout = Layout.row_major(block_size, block_size) var global_fragment = global_tile.distribute[thread_layout](tid) var shared_fragment = shared_tile.distribute[thread_layout](tid) # async copy to shared memory shared_fragment.copy_from_async(global_fragment) async_copy_wait_all() # ... do something with the shared tile ``` Performance: * Supports vectorized copies for 4, 8, or 16-byte elements for better throughput. * Can bypass L1 cache with appropriate eviction policies for specific access patterns. * Swizzling can improve memory access patterns and reduce bank conflicts. Notes: * For vectorized copies, both tensors must have contiguous element layouts. * Asynchronous copies allow computation to overlap with memory transfers. 
* A synchronization barrier is required before using the copied data. **Constraints:** * Destination must be in shared memory. * Source and destination data types must match. * Element size must be 4, 8, or 16 bytes. * Destination tensor must have a static layout. **Parameters:** * ​is\_masked ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to perform a masked copy, where elements at or beyond `src_idx_bound` are either skipped or filled with zeros, depending on the `fill` policy. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns. * ​fill ([`Fill`](/mojo/stdlib/gpu/memory/memory/Fill)): Fill policy for elements that are not copied (only used with masked copies). * ​eviction\_policy ([`CacheEviction`](/mojo/stdlib/gpu/memory/memory/CacheEviction)): Cache eviction policy for the source data. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor to copy data from. * ​src\_idx\_bound ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): For masked copies, the upper bound index for valid source elements. * ​base\_offset ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Base offset for swizzling calculations. ### `fill` `fill[*, use_runtime_layout: Bool = layout.all_dims_known().__bool__().__invert__() if (not layout.all_dims_known()._mlir_value) else (layout.size() > 2048)](self: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], val: Scalar[dtype]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Fill the entire tensor with a single value. This method sets all elements of the tensor to the specified value. 
It works with both statically and dynamically shaped tensors. For statically known layouts, the fill operation is unrolled at compile time. For dynamic layouts, a runtime loop is used. No vectorization is applied, so performance may be suboptimal for large tensors. Consider using hardware-specific fill operations for better performance with large tensors. This method can be used with tensors of any rank and shape. The fill operation respects the tensor's layout, filling all elements regardless of how they are arranged in memory. For tensors with `element_layout`, all elements within each logical element are filled with the same value. Example: ```mojo from layout import Layout, LayoutTensor def main(): var storage = InlineArray[Float32, 3 * 4](uninitialized=True) var tensor = LayoutTensor[ DType.float32, Layout([3, 4]), ](storage).fill(0.0) print(tensor) ``` If not using method chaining, you can either reassign the result to the tensor variable, or assign the result to the discard pattern (`_`) to avoid warnings about an unused value: ```mojo tensor = tensor.fill(0.0) # or _ = tensor.fill(0.0) ``` **Parameters:** * ​use\_runtime\_layout ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to use the runtime layout for filling. This parameter is defaulted to `True` if the layout is not statically known. If loop bounds are too large, it's better to use the runtime layout to avoid long compilation time. **Args:** * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to fill the tensor with. Must be of the same data type as the tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The tensor itself (self), allowing for method chaining. ### `__str__` `__str__(self) -> String` Convert the tensor to a string representation. This method converts the tensor to a human-readable string representation by writing its contents to a string. 
It delegates to the `write_to` method which formats the tensor appropriately based on its rank and shape. **Returns:** `String`: A string representation of the tensor. ### `write_to` `write_to(self, mut writer: T)` Format and write the tensor's contents to a writer. This method formats the tensor's contents and writes them to the provided writer. For 2D tensors, it formats the output in a 2D grid. For tensors of other ranks, it prints all values in column-major coordinate order. Example: ```mojo from layout import Layout, LayoutTensor def main(): var storage = InlineArray[Float32, 2 * 3](uninitialized=True) var tensor = LayoutTensor[ DType.float32, Layout([2, 3]), ](storage).fill(1.0) print(tensor) # Internally calls `write_to` with a StringWriter ``` Output for a 2x3 tensor: ``` [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]] ``` Notes: * For 2D tensors, the output is formatted as a 2D grid with rows and columns. * For tensors of other ranks, values are printed in column-major coordinate order. * Empty tensors (size 0) produce no output. * This method is used by the `__str__` method to convert the tensor to a string. * The formatting is designed for human readability rather than parsing. * For large tensors, the output may be truncated to avoid excessive output. **Args:** * ​writer (`T`): The writer instance to write the formatted output to.
--- ## LayoutTensorIter
`@register_passable(trivial)` `struct LayoutTensorIter[mut: Bool, //, dtype: DType, layout: Layout, origin: Origin[mut], /, *, address_space: AddressSpace = AddressSpace.GENERIC, alignment: Int = align_of[dtype](), circular: Bool = False, axis: OptionalReg[Int] = None, layout_int_type: DType = _get_index_type(address_space), linear_idx_type: DType = _get_index_type(address_space), masked: Bool = False]` Iterator for traversing a memory buffer with a specific layout. `LayoutTensorIter` provides a way to iterate through memory according to a specific layout pattern, constructing layout tensors at each position. This enables efficient traversal of multi-dimensional data structures with custom memory layouts. Notes: The returned layout tensor is NOT vectorized. Users should explicitly vectorize if needed for performance-critical operations. ## Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterator allows mutation of the underlying data. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the tensor elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The memory layout pattern to follow during iteration. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): Origin tracking for memory safety. * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The memory address space (`GLOBAL`, `SHARED`, etc.). * ​alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Memory alignment requirement for the data. * ​circular ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether iteration wraps around at boundaries. * ​axis ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional axis for dimension-specific operations. * ​layout\_int\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Integer type used for layout indices. * ​linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Integer type used for indexing into memory. 
* ​masked ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to apply bounds masking during iteration. ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin]`): Pointer to the memory region being iterated, with appropriate type and memory attributes. * ​offset (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Current offset from the base pointer, representing the iterator's position in memory. * ​stride (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Step size between consecutive elements or blocks in memory during iteration. * ​bound (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Upper bound of the memory region, limiting the iteration range. * ​runtime\_layout (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].RuntimeLayoutType`): Runtime representation of the layout pattern used for mapping logical indices to memory locations. * ​dimension\_bound (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].layout_uint_type`): Boundary value for the current dimension when iterating along a specific axis. 
* ​idx (`LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].linear_uint_type`): Current logical index position within the iteration sequence. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BitcasType` `comptime BitcasType[new_type: DType, *, address_space: AddressSpace = address_space, alignment: Int = alignment] = LayoutTensorIter[new_type, layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for bitcast iterator types. #### Parameters * ​new\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The target data type. * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The target address space. * ​alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): The target memory alignment. ### `layout_uint_type` `comptime layout_uint_type = Scalar[_unsigned_integral_type_of[layout_int_type]()]` The unsigned integer type used for layout, based on layout and address space. 
### `LayoutTensorType` `comptime LayoutTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` The LayoutTensor type returned by this iterator. ### `linear_uint_type` `comptime linear_uint_type = Scalar[_unsigned_integral_type_of[linear_idx_type]()]` The unsigned integer type used for indexing into memory. ### `ReshapeType` `comptime ReshapeType[dst_layout: Layout] = LayoutTensorIter[dtype, dst_layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Type alias for reshaped iterator types. #### Parameters * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout for the reshaped iterator. ### `RuntimeLayoutType` `comptime RuntimeLayoutType = RuntimeLayout[layout, element_type=layout_int_type, linear_idx_type=linear_idx_type]` Type alias for the runtime layout. ## Methods ### `__init__` `__init__() -> Self` Initialize an empty iterator. Creates a default iterator with zero values, typically used as a placeholder or default value. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin], bound: Scalar[_unsigned_integral_type_of[linear_idx_type]()], stride: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = layout.size(), offset: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 0) -> Self` Initialize an iterator with a pointer and basic parameters. Creates an iterator for a memory region with the specified bounds and stride. **Constraints:** The layout must have all dimensions known at compile time. **Args:** * ​ptr (`LegacyUnsafePointer`): Pointer to the beginning of the memory region. * ​bound ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Upper bound of the memory region. 
* ​stride ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Step size between consecutive elements (defaults to layout size). * ​offset ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Initial offset from the base pointer. `__init__(ptr: LegacyUnsafePointer[Scalar[dtype], address_space=address_space, mut=mut, origin=origin], bound: Scalar[_unsigned_integral_type_of[linear_idx_type]()], runtime_layout: RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type], stride: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = layout.size() if layout.all_dims_known() else -1, offset: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 0, dimension_bound: Scalar[_unsigned_integral_type_of[layout_int_type]()] = 0, idx: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 0) -> Self` Initialize an iterator with a runtime layout. Creates an iterator with a runtime-determined layout, allowing for more flexible memory traversal patterns. **Constraints:** The runtime layout must have the same bitwidth as specified for the iterator. Circular iteration is not supported when an axis is defined. **Args:** * ​ptr (`LegacyUnsafePointer`): Pointer to the beginning of the memory region. * ​bound ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Upper bound of the memory region. * ​runtime\_layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): Layout determined at runtime. * ​stride ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Step size between consecutive elements. * ​offset ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Initial offset from the base pointer. * ​dimension\_bound ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Bound for the specified dimension when using masked iteration. * ​idx ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Initial index position. 
### `__getitem__` `__getitem__(self) -> LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].LayoutTensorType` Get the layout tensor at the current iterator position. Operator overload that returns a layout tensor representing the data at the current position of the iterator. **Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A layout tensor at the current iterator position. ### `__iadd__` `__iadd__[T: Intable](mut self, rhs: T)` Increment the iterator by an integer value. Advances the iterator by the specified number of positions. Notes: This function is unsafe. It omits bound checking for performance reasons. Caller must ensure the index doesn't go out-of-bounds. **Parameters:** * ​T ([`Intable`](/mojo/stdlib/builtin/int/Intable)): A type that can be converted to an integer. **Args:** * ​rhs (`T`): The number of positions to advance. `__iadd__(mut self, rhs: Scalar[_unsigned_integral_type_of[linear_idx_type]()])` Increment the iterator by a uint value. Advances the iterator by the specified number of positions. Notes: This function is unsafe. It omits bound checking for performance reasons. Caller must ensure the index doesn't go out-of-bounds. **Args:** * ​rhs ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The number of positions to advance. ### `get` `get(self) -> LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked].LayoutTensorType` Get the layout tensor at the current iterator position. Returns a layout tensor representing the data at the current position of the iterator. 
**Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A tensor view at the current iterator position with the same type, layout, and memory characteristics as specified by the output parameter. ### `next` `next[T: Intable](self, rhs: T) -> Self` Return an iterator pointing to a position ahead by rhs steps. Creates a new iterator that points rhs positions ahead of the current one. **Parameters:** * ​T ([`Intable`](/mojo/stdlib/builtin/int/Intable)): An integer-convertible type for the step size. **Args:** * ​rhs (`T`): The number of positions to advance. **Returns:** `Self`: A new iterator pointing to the advanced position. `next(self, rhs: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 1) -> Self` Return an iterator pointing to a position ahead by rhs steps. Creates a new iterator that points rhs positions ahead of the current one. **Args:** * ​rhs ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The number of positions to advance (defaults to 1). **Returns:** `Self`: A new iterator pointing to the advanced position. ### `next_unsafe` `next_unsafe(self, rhs: Scalar[_unsigned_integral_type_of[linear_idx_type]()] = 1) -> Self` Return an iterator pointing to a position ahead by rhs steps (unsafe version). Creates a new iterator that points rhs positions ahead of the current one. This is an unsafe version that omits certain checks for performance. **Constraints:** Cannot be used with masked iterators. User must ensure rhs < bound / stride. **Args:** * ​rhs ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The number of positions to advance (defaults to 1). **Returns:** `Self`: A new iterator pointing to the advanced position. ### `reshape` `reshape[dst_layout: Layout](self) -> LayoutTensorIter[dtype, dst_layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Reshape the iterator to a new layout. 
This method creates a new iterator with a different layout while preserving the underlying data. The new layout must have the same total size as the original. **Constraints:** * The destination layout must have the same total size as the original. * Both layouts must be contiguous. * Both layouts must have compile-time known dimensions. **Parameters:** * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The target layout to reshape to. **Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A new iterator with the specified layout. ### `bitcast` `bitcast[new_type: DType, *, target_address_space: AddressSpace = address_space, target_alignment: Int = alignment](self) -> LayoutTensorIter[new_type, layout, origin, address_space=address_space, alignment=alignment, circular=circular, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked]` Reinterpret the iterator's underlying pointer as a different data type. This method performs a bitcast operation, allowing you to view the same memory location as a different data type without copying or converting the data. **Parameters:** * ​new\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The target data type to cast to. * ​target\_address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The memory address space for the new iterator (defaults to current). * ​target\_alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Memory alignment requirement for the new iterator (defaults to current). **Returns:** [`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter): A new LayoutTensorIter with the same layout but different data type.
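As an illustrative sketch of typical iterator usage (the buffer size, layouts, and origin parameter here are made up for the example, and the pointer handoff may need adjustment in practice), an iterator can step through a buffer in fixed-size tiles, yielding a `LayoutTensor` view at each position:

```mojo
from layout import Layout, LayoutTensor
from layout.layout_tensor import LayoutTensorIter


def main():
    # Backing storage: 8 floats, traversed as four consecutive 1x2 tiles.
    var storage = InlineArray[Float32, 8](uninitialized=True)
    var buffer = LayoutTensor[
        DType.float32,
        Layout.row_major(1, 8),
    ](storage)
    comptime tile_layout = Layout.row_major(1, 2)
    var it = LayoutTensorIter[
        DType.float32,
        tile_layout,
        MutAnyOrigin,
    ](buffer.ptr, 8)
    for i in range(4):
        # `it[]` returns a LayoutTensor view of the current tile.
        _ = it[].fill(Float32(i))
        it += 1  # Advance one tile; stride defaults to layout.size().
```

The returned views are not vectorized; call `vectorize()` on them explicitly if a performance-critical loop needs SIMD access.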
--- ## ThreadScope
`@register_passable(trivial)` `struct ThreadScope` Represents the scope of thread operations in GPU programming. This struct defines the scope at which thread operations are performed, particularly for operations like tensor distribution and synchronization. It provides two main scopes: `BLOCK` and `WARP`, which correspond to different levels of thread grouping in GPU programming models. Example: ```mojo from layout.layout_tensor import copy_dram_to_sram, ThreadScope # Distribute tensor at block level (all threads in block participate) copy_dram_to_sram[layout, thread_scope=ThreadScope.BLOCK](dst, src) # Distribute tensor at warp level (only threads in same warp participate) copy_dram_to_sram[layout, thread_scope=ThreadScope.WARP](dst, src) ``` Performance: * WARP scope operations typically have lower synchronization overhead than BLOCK scope operations. * BLOCK scope operations allow coordination across all threads in a block, which is necessary for certain algorithms. * The choice of scope can significantly impact performance and correctness of parallel algorithms. Notes: * The appropriate scope depends on the specific algorithm and hardware. * WARP scope operations may be more efficient for operations that only require coordination within a warp. * BLOCK scope operations are necessary when threads from different warps need to coordinate. * The actual size of a warp or block is hardware-dependent. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BLOCK` `comptime BLOCK = ThreadScope(0)` Represents operations at the thread block level, where all threads in a block participate. ### `WARP` `comptime WARP = ThreadScope(1)` Represents operations at the warp level, where only threads within the same warp participate. ## Methods ### `__init__` `__init__(value: Int) -> Self` Initialize a `ThreadScope` with the given integer value. **Args:** * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): An integer representing the thread scope (0 for `BLOCK`, 1 for `WARP`). ### `__eq__` `__eq__(self, other: Self) -> Bool` Compare two `ThreadScope` objects for equality. **Args:** * ​other (`Self`): Another `ThreadScope` object to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the thread scopes are equal, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Compare two `ThreadScope` objects for inequality. **Args:** * ​other (`Self`): Another `ThreadScope` object to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the thread scopes are not equal, False otherwise. ### `__str__` `__str__(self) -> String` Convert the `ThreadScope` to a human-readable string representation. Aborts: If the thread scope has an invalid value. **Returns:** `String`: A string representation of the thread scope ("BLOCK" or "WARP"). ### `__int__` `__int__(self) -> Int` Convert the `ThreadScope` to an integer value. 
**Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value of the thread scope (0 for BLOCK, 1 for WARP).
--- ## copy_dram_to_local
`copy_dram_to_local[src_thread_layout: Layout, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1, cache_policy: CacheOperation = CacheOperation.ALWAYS](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_base: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], offset: OptionalReg[UInt] = None)` Efficiently copy data from global memory (DRAM) to registers for AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's buffer\_load intrinsic to efficiently transfer data from global memory to registers while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. Notes: * The offset calculation method significantly impacts performance. Current implementation optimizes for throughput over flexibility. * This function is particularly useful for prefetching data into registers before performing computations, reducing memory access latency. **Constraints:** * Only supported on AMD GPUs. * The destination element layout size must match the SIMD width. * Source fragments must be rank 2 with known dimensions. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. 
* ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of src\_thread\_layout. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. * ​cache\_policy ([`CacheOperation`](/mojo/stdlib/gpu/memory/memory/CacheOperation)): The cache policy to use for the copy operation. Defaults to `CacheOperation.ALWAYS`. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in register memory (LOCAL address space). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in global memory (DRAM) to be copied. * ​src\_base ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The original global memory tensor from which src is derived. This is used to construct the buffer struct required by AMD's `buffer_load` intrinsic. * ​offset ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The offset in the global memory. 
`copy_dram_to_local[src_thread_layout: Layout, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1, cache_policy: CacheOperation = CacheOperation.ALWAYS](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], bounds: UInt32)` Efficiently copy data from global memory (DRAM) to registers for AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's buffer\_load intrinsic to efficiently transfer data from global memory to registers while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. Notes: * The offset calculation method significantly impacts performance. The current implementation optimizes for throughput over flexibility. * This function is particularly useful for prefetching data into registers before performing computations, reducing memory access latency. **Constraints:** * Only supported on AMD GPUs. * The destination element layout size must match the SIMD width. * Source fragments must be rank 2 with known dimensions. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of `src_thread_layout`. 
* ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. * ​cache\_policy ([`CacheOperation`](/mojo/stdlib/gpu/memory/memory/CacheOperation)): The cache policy to use for the copy operation. Defaults to `CacheOperation.ALWAYS`. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in register memory (LOCAL address space). * ​src\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The source tensor iterator. * ​bounds ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Bounds of the buffer, based on the ptr of the src\_iter. `copy_dram_to_local[src_thread_layout: Layout, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Efficiently copy data from global memory (DRAM) to registers. This function implements an optimized memory transfer operation from global memory to register memory. It distributes the copy operation across multiple threads for maximum throughput while handling bounds checking for safety. **Constraints:** * The source tensor must be in GLOBAL address space (DRAM). * The destination tensor must be in LOCAL address space (registers). 
* Both tensors must have compatible data types. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of `src_thread_layout`. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in register memory (LOCAL address space). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in global memory (DRAM).
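The register-prefetch pattern described above might look like the following minimal sketch. The tile shape, thread layout, per-thread fragment shape, and the `gpu.memory` import path are illustrative assumptions, not part of the reference:

```mojo
# Sketch only: shapes, layouts, and import paths are assumed for illustration.
from gpu.memory import AddressSpace
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_dram_to_local

alias thread_layout = Layout.row_major(8, 4)  # 32 cooperating threads

fn prefetch_tile(src: LayoutTensor[DType.float32, Layout.row_major(32, 32), MutAnyOrigin]):
    # Per-thread register fragment: with a 32x32 tile split across an
    # 8x4 thread layout, each thread holds a 4x8 piece.
    var reg_tile = LayoutTensor[
        DType.float32,
        Layout.row_major(4, 8),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation()

    # Distribute the source across the thread layout and load this
    # thread's fragment from DRAM into registers.
    copy_dram_to_local[thread_layout](reg_tile, src)
```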
--- ## copy_dram_to_sram
`copy_dram_to_sram[src_thread_layout: Layout, dst_thread_layout: Layout = src_thread_layout, swizzle: OptionalReg[Swizzle] = None, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. This function performs a synchronous copy operation from global memory (DRAM) to shared memory (SRAM) in a GPU context, distributing the workload across multiple threads for parallel execution. It uses thread affinity mapping to ensure efficient work distribution and supports vectorized memory operations for optimal performance. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Supports vectorized loads and stores for better memory throughput. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. * Thread affinity mapping ensures efficient work distribution. * For masked tensors, performs bounds checking to handle edge cases correctly. Notes: * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. * This function is synchronous, meaning all threads must complete their copy operations before proceeding. * For optimal performance, the thread layouts should match the memory access patterns of the tensors. * This function is particularly useful in GPU kernels for loading data from global memory to shared memory for faster access. 
**Constraints:** * Source and destination tensors must have the same data type. * Source tensor must be in GENERIC or GLOBAL address space. * Destination tensor must be in SHARED address space. * For non-masked tensors, the fragment sizes must match. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the source tensor. This determines how the workload is distributed among threads. * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the destination tensor. Defaults to the same as `src_thread_layout` if not specified. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of `src_thread_layout`. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Scope at which thread operations are performed (`BLOCK` or `WARP`). Defaults to `ThreadScope.BLOCK`, where all threads in a block participate. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM). 
`copy_dram_to_sram[src_thread_layout: Layout, dst_thread_layout: Layout = src_thread_layout, swizzle: OptionalReg[Swizzle] = None, num_threads: Int = src_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], bound: Int)` Efficiently copy data from global memory (DRAM) to shared memory (SRAM) on AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's `buffer_load` intrinsic to efficiently transfer data while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the source tensor across threads. This determines how the workload is divided among participating threads. * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the destination tensor across threads. Defaults to the same layout as `src_thread_layout`. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzling pattern to apply when distributing the destination tensor. This can improve memory access patterns and reduce bank conflicts. Defaults to None (no swizzling). * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): The total number of threads participating in the copy operation. Defaults to the size of `src_thread_layout`. 
* ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory (SRAM). * ​src\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The source tensor iterator in global memory (DRAM) to be copied. * ​bound ([`Int`](/mojo/stdlib/builtin/int/Int)): The bound of the source tensor iterator. `copy_dram_to_sram[thread_layout: Layout, swizzle: OptionalReg[Swizzle] = None, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], bound: Int)` Synchronously copy data from DRAM to SRAM using a unified thread layout for AMD GPUs. This is a convenience wrapper around the more general `copy_dram_to_sram()` function that uses the same layout for both source and destination tensors. It's specifically designed for AMD GPUs where the buffer\_load intrinsic requires the original base tensor. Performance: * Simplifies API usage when the same thread layout is appropriate for both source and destination tensors. * Optimized for AMD GPUs using buffer\_load intrinsics for efficient memory transfers. 
* Distributes the copy workload across multiple threads for parallel execution. Notes: * This function is only supported on AMD GPUs. * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for both source and destination. This determines how the workload is distributed among threads. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of thread\_layout. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Scope at which thread operations are performed (`BLOCK` or `WARP`). Defaults to `BLOCK`, where all threads in a block participate. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The source tensor iterator, which must be in global or generic memory (DRAM). * ​bound ([`Int`](/mojo/stdlib/builtin/int/Int)): The bound of the source tensor iterator. 
`copy_dram_to_sram[thread_layout: Layout, swizzle: OptionalReg[Swizzle] = None, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from DRAM to SRAM using a unified thread layout. This is a convenience wrapper around the more general `copy_dram_to_sram()` function that uses the same layout for both source and destination tensors. It simplifies the API for the common case where the same thread distribution pattern works well for both tensors. Performance: * Simplifies API usage when the same thread layout is appropriate for both source and destination tensors. * Distributes the copy workload across multiple threads for parallel execution. * Supports vectorized loads and stores for better memory throughput. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. Notes: * The source tensor must be in `GENERIC` or `GLOBAL` address space (DRAM). * The destination tensor must be in `SHARED` address space (SRAM). * Both tensors must have the same data type. * This function is synchronous, meaning all threads must complete their copy operations before proceeding. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for both source and destination. This determines how the workload is distributed among threads. 
* ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of `thread_layout`. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Scope at which thread operations are performed (`BLOCK` or `WARP`). Defaults to `ThreadScope.BLOCK`, where all threads in a block participate. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM).
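A typical use of the unified-layout overload is staging a tile of global memory into shared memory at the top of a kernel. In this sketch the tile size, thread layout, and the `gpu.memory` import path are assumptions for illustration:

```mojo
# Sketch only: tile size, thread layout, and import paths are assumed.
from gpu.memory import AddressSpace
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_dram_to_sram

alias tile_layout = Layout.row_major(32, 32)
alias thread_layout = Layout.row_major(8, 4)  # 32 cooperating threads

fn stage_tile(src: LayoutTensor[DType.float32, tile_layout, MutAnyOrigin]):
    # Shared-memory destination for the tile.
    var smem = LayoutTensor[
        DType.float32,
        tile_layout,
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # All threads in the block cooperatively copy the tile from DRAM
    # to SRAM, using the same layout to distribute source and destination.
    copy_dram_to_sram[thread_layout](smem, src)
```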
--- ## copy_dram_to_sram_async
`copy_dram_to_sram_async[src_thread_layout: Layout, dst_thread_layout: Layout, swizzle: Bool = False, fill: Fill = Fill.NONE, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_threads: Int = src_thread_layout.size(), block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Asynchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. This function performs an asynchronous copy operation from global memory (DRAM) to shared memory (SRAM) in a GPU context, using NVIDIA's cp.async hardware mechanism. It distributes the workload across multiple threads and allows computation to overlap with memory transfers for improved performance. Performance: * Performs asynchronous transfers, allowing computation to overlap with memory operations. * Distributes the copy workload across multiple threads for parallel execution. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. * Supports different cache eviction policies to optimize memory hierarchy usage. * For masked tensors, performs bounds checking to handle edge cases correctly. Notes: * This function requires NVIDIA GPUs with `cp.async` support (compute capability 8.0+). * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. 
* This function is asynchronous, so you must call [`async_copy_wait_all()`](/mojo/stdlib/gpu/memory/async_copy_wait_all/) or [`async_copy_wait_group()`](/mojo/stdlib/gpu/memory/async_copy_wait_group/) to ensure the copy has completed before using the data. * The maximum size of each element that can be copied is 16 bytes. **Constraints:** * Requires NVIDIA GPUs with cp.async support (compute capability 8.0+). * Source tensor must be in `GENERIC` or `GLOBAL` address space. * Destination tensor must be in `SHARED` address space. * Both tensors must have the same data type. * Element size must be 4, 8, or 16 bytes. **Parameters:** * ​src\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the source tensor. This determines how the workload is distributed among threads. * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the destination tensor. * ​swizzle ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to apply swizzling to the destination indices to reduce bank conflicts. Defaults to False. * ​fill ([`Fill`](/mojo/stdlib/gpu/memory/memory/Fill)): Fill policy for handling out-of-bounds accesses. Options include: * `Fill.NONE`: No special handling (default). * `Fill.ZERO`: Fill out-of-bounds elements with zeros. * ​eviction\_policy ([`CacheEviction`](/mojo/stdlib/gpu/memory/memory/CacheEviction)): Cache eviction policy for the source data. Options include: * `CacheEviction.EVICT_NORMAL`: Normal eviction (default). * `CacheEviction.EVICT_FIRST`: Evict data after first use. * `CacheEviction.EVICT_LAST`: Keep data in cache until last use. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of src\_thread\_layout. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. 
**Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM). `copy_dram_to_sram_async[thread_layout: Layout, swizzle: Bool = False, masked: Bool = False, fill: Fill = Fill.NONE, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL, num_threads: Int = thread_layout.size(), block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Asynchronous copy from DRAM to SRAM with thread affinity mapping. This function performs an asynchronous memory transfer from DRAM (global memory) to SRAM (shared memory) using the specified thread layout for distribution. Notes: This is a convenience wrapper around the more general `copy_dram_to_sram_async()` function, using the same thread layout for both source and destination. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute work across threads. * ​swizzle ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to apply memory access swizzling for better performance. * ​masked ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the copy operation should use masking. * ​fill ([`Fill`](/mojo/stdlib/gpu/memory/memory/Fill)): Fill policy for uninitialized memory regions. * ​eviction\_policy ([`CacheEviction`](/mojo/stdlib/gpu/memory/memory/CacheEviction)): Cache eviction policy to use during the transfer. 
* ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of threads to use for the operation, defaults to the size of `thread_layout`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in SRAM. * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tensor in DRAM.
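Because the copy is asynchronous, the issue/wait structure matters: issue the transfer, optionally do independent work, then wait before reading the shared tile. A minimal sketch (NVIDIA GPUs with `cp.async`, compute capability 8.0+; tile size, thread layout, and import paths are assumed):

```mojo
# Sketch only: shapes, layouts, and import paths are assumed.
from gpu.memory import AddressSpace, async_copy_wait_all
from gpu.sync import barrier
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_dram_to_sram_async

alias tile_layout = Layout.row_major(32, 32)
alias thread_layout = Layout.row_major(8, 4)

fn stage_tile_async(src: LayoutTensor[DType.float32, tile_layout, MutAnyOrigin]):
    var smem = LayoutTensor[
        DType.float32,
        tile_layout,
        MutAnyOrigin,
        address_space = AddressSpace.SHARED,
    ].stack_allocation()

    # Issue the asynchronous copy; it can overlap with independent work.
    copy_dram_to_sram_async[thread_layout](smem, src)

    # ... independent computation can go here ...

    # Wait until all outstanding async copies have landed in SRAM,
    # then synchronize the block before any thread reads the tile.
    async_copy_wait_all()
    barrier()
```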
--- ## copy_local_to_dram
`copy_local_to_dram[dst_thread_layout: Layout, num_threads: Int = dst_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Efficiently copy data from registers (LOCAL) to global memory (DRAM). This function implements a high-performance memory transfer operation from register memory to global memory. It distributes the copy operation across multiple threads for maximum throughput while handling bounds checking for safety. **Constraints:** * The source tensor must be in LOCAL address space (registers). * The destination tensor must be in GENERIC or GLOBAL address space (DRAM). * Both tensors must have compatible data types. **Parameters:** * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the destination tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of `dst_thread_layout`. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. 
**Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in global memory (DRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in register memory (LOCAL) to be copied. `copy_local_to_dram[dst_thread_layout: Layout, num_threads: Int = dst_thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst_base: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Efficiently copy data from registers (LOCAL) to global memory (DRAM) on AMD GPUs. This function implements an optimized memory transfer operation specifically for AMD GPU architectures. It utilizes the hardware's buffer\_store intrinsic to efficiently transfer data from registers to global memory while handling bounds checking. The function distributes the copy operation across multiple threads for maximum throughput. Notes: * This function is particularly useful for writing computed results from registers back to global memory with minimal latency. * The offset calculation is optimized for performance rather than flexibility. **Constraints:** * Only supported on AMD GPUs. * Destination tensor must be in GLOBAL address space. * Source tensor must be in LOCAL address space. * Data types must match between source and destination tensors. 
**Parameters:** * ​dst\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout used to distribute the destination tensor across threads. This determines how the workload is divided among participating threads. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of `dst_thread_layout`. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in global memory (DRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in register memory (LOCAL address space) to be copied. * ​dst\_base ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The original global memory tensor from which dst is derived. This is used to construct the buffer descriptor required by AMD's `buffer_store` intrinsic.
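The write-back direction mirrors the register-prefetch pattern: compute into a per-thread fragment, then scatter it to the thread's portion of the global tile. A minimal sketch of the general (non-AMD-specific) overload; fragment shape, thread layout, and import paths are assumptions:

```mojo
# Sketch only: shapes, layouts, and import paths are assumed.
from gpu.memory import AddressSpace
from layout import Layout, LayoutTensor
from layout.layout_tensor import copy_local_to_dram

alias thread_layout = Layout.row_major(8, 4)  # 32 cooperating threads

fn write_back(dst: LayoutTensor[DType.float32, Layout.row_major(32, 32), MutAnyOrigin]):
    # Per-thread register fragment holding this thread's results.
    var reg_tile = LayoutTensor[
        DType.float32,
        Layout.row_major(4, 8),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation().fill(0)

    # ... compute into reg_tile ...

    # Scatter each thread's fragment back to its portion of the
    # destination tile in global memory.
    copy_local_to_dram[thread_layout](dst, reg_tile)
```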
--- ## copy_local_to_local
`copy_local_to_local(dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data between local memory (register) tensors with type conversion. This function performs a synchronous copy operation between register tensors in a GPU context, with support for converting from float32 to half-precision formats (bfloat16/float16). It's particularly optimized for specific tensor layouts commonly used in matrix multiplication operations. Example:

```mojo
from gpu.memory import AddressSpace
from layout import LayoutTensor, Layout
from layout.layout_tensor import copy_local_to_local

fn kernel():
    ...
    var src_reg = LayoutTensor[
        DType.float32,
        Layout.row_major(16, 8),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation().fill(1)
    var dst_reg = LayoutTensor[
        DType.bfloat16,
        Layout.row_major(16, 8),
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ].stack_allocation()

    # Process data in float32 registers
    # ...

    # Convert and copy to bfloat16 registers
    copy_local_to_local(dst_reg, src_reg)
```

Performance: * Optimized for specific 2D tensor layouts with contiguous inner dimensions. * Special fast path for 2D tensors with specific layouts used in matrix multiplication. * For MMA (Matrix Multiply-Accumulate) operations, efficiently handles the conversion between output fragments and input fragments with different layouts. * Falls back to element-wise copy for general cases. Notes: * Both source and destination tensors must be in `LOCAL` address space (registers). * This function currently only supports copying from float32 to half-precision formats. 
* For 2D tensors with stride\[1] == 1, a specialized fast path is used that's optimized for matrix multiplication patterns. * This function is particularly useful in GPU kernels for converting between different precision formats while keeping data in registers. **Constraints:** * Destination tensor must be in `LOCAL` address space. * Source tensor must be in `LOCAL` address space. * Destination tensor must have a half-precision floating-point data type. * Source tensor must have float32 data type. * Both tensors must have the same total size. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in local memory (registers) and have a half-precision floating-point data type (bfloat16 or float16). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in local memory (registers) and have float32 data type.
--- ## copy_local_to_shared
`copy_local_to_shared[thread_layout: Layout, swizzle: OptionalReg[Swizzle] = None, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1, *, row_major: Bool = False](dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from local memory (registers) to SRAM (shared memory). This function performs a synchronous copy operation from register memory to shared memory in a GPU context, distributing the workload across multiple threads for parallel execution. It's particularly useful for transferring processed data from registers to shared memory for inter-thread communication. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Can use swizzling to optimize memory access patterns and reduce bank conflicts. * Optimized for transferring data from registers to shared memory. * On AMD GPUs, the `row_major` parameter can be used to match the memory access pattern used during prefetching from DRAM to registers. Notes: * The destination tensor must be in `SHARED` address space (SRAM). * The source tensor must be in `LOCAL` address space (registers). * This function is particularly useful in GPU kernels for sharing processed data between threads in the same block. * The `row_major` parameter is specifically designed for AMD GPUs when using a prefetching pattern from DRAM to SRAM via registers. **Constraints:** * Destination tensor must be in SHARED address space. * Source tensor must be in LOCAL address space. 
* For optimal performance, the thread layout should match the memory access patterns of the tensors. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the operation. This determines how the workload is distributed among threads. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzling function to rearrange the destination indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of thread\_layout. * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Defines whether operations are performed at `BLOCK` or `WARP` level. `BLOCK` scope involves all threads in a thread block, while `WARP` scope restricts operations to threads within the same warp. Defaults to `ThreadScope.BLOCK`. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. * ​row\_major ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to use row-major ordering for the copy operation. This is particularly relevant when prefetching from DRAM to SRAM via registers on AMD GPUs. Defaults to False. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in local memory (registers).
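A schematic kernel fragment showing the typical staging pattern (illustrative, not a complete program): `reg_tile` is assumed to be a `LOCAL`-space `LayoutTensor` already populated in registers, `smem_tile` a `SHARED`-space tile of the same shape, and the thread layout is an arbitrary example.

```mojo
from gpu.sync import barrier
from layout import Layout
from layout.layout_tensor import copy_local_to_shared

# Inside a GPU kernel body, with 32 threads cooperating on the tile:
copy_local_to_shared[thread_layout = Layout.row_major(1, 32)](
    smem_tile, reg_tile
)
# The copy itself is synchronous per thread, but a barrier is still
# needed before *other* threads read the freshly written shared memory.
barrier()
```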
--- ## copy_sram_to_dram
`copy_sram_to_dram[thread_layout: Layout, swizzle: OptionalReg[Swizzle] = None, num_threads: Int = thread_layout.size(), block_dim_count: Int = 1, binary_op: OptionalReg[fn[dtype: DType, width: Int](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[dtype, width]] = None](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from SRAM (shared memory) to DRAM (global memory). This function performs a synchronous memory transfer from SRAM (shared memory) to DRAM (global memory) using the specified thread layout for workload distribution. It supports optional swizzling for optimized memory access patterns and binary operations for combining data during the transfer. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Supports vectorized loads and stores for better memory throughput. * Can use swizzling to optimize memory access patterns. * Supports binary operations to combine data during transfer (e.g., for reduction operations). Notes: * The source tensor must be in `SHARED` address space (SRAM). * The destination tensor must be in `GENERIC` or `GLOBAL` address space (DRAM). * Supports FP32 to half-precision downcast during copy if needed. * Handles masked tensors with proper bounds checking. * This function is synchronous, meaning all threads must complete their copy operations before proceeding. **Constraints:** * Source tensor must be in SHARED address space with a static layout. * Destination tensor must be in GENERIC or GLOBAL address space. * For type conversion, only FP32 to half-precision is supported. 
* For vectorized copy with type conversion, both tensors must have element layouts matching the SIMD width of the destination type. **Parameters:** * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for both source and destination. This determines how the workload is distributed among threads. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzling function to rearrange the source indices, which can improve memory access patterns and reduce bank conflicts. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of threads participating in the copy operation. Defaults to the size of thread\_layout. * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the thread block. * ​binary\_op ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional binary operation to apply during the copy, combining source data with existing destination data. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in global or generic memory (DRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in shared memory (SRAM).
--- ## copy_sram_to_local
`copy_sram_to_local[src_warp_layout: Layout, axis: OptionalReg[Int] = None](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Synchronously copy data from SRAM (shared memory) to local memory. This function performs a synchronous memory transfer from SRAM (shared memory) to local memory (registers) using the specified thread layout for workload distribution. Performance: * Distributes the copy workload across multiple threads for parallel execution. * Optimized for transferring data from shared memory to registers. * Supports optional axis-specific distribution for specialized access patterns. **Constraints:** * The source tensor must be in SHARED address space (SRAM). * The destination tensor must be in LOCAL address space (registers). * Both tensors must have the same data type. **Parameters:** * ​src\_warp\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout defining how threads are organized for the source tensor. This determines how the workload is distributed among threads. * ​axis ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional parameter specifying which axis to distribute along. When provided, distribution happens along the specified axis. When None (default), distribution uses the standard layout pattern. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in local memory (registers). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in shared memory (SRAM).
--- ## cp_async_k_major
`cp_async_k_major[dtype: DType, eviction_policy: CacheEviction = CacheEviction.EVICT_NORMAL](dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Asynchronously copy data from DRAM to SRAM using TMA (Tensor Memory Accelerator) with K-major layout. This function performs an asynchronous copy operation from global memory (DRAM) to shared memory (SRAM) using NVIDIA's Tensor Memory Accelerator (TMA) hardware. It optimizes for K-major memory access patterns, which is particularly beneficial for certain tensor operations like matrix multiplications where the inner dimension (K) is accessed contiguously. The function automatically determines the optimal tile size and thread distribution based on the tensor shapes and hardware capabilities, leveraging TMA's efficient memory transfer mechanisms. Performance: * Uses TMA hardware acceleration for optimal memory transfer performance. * Optimizes for K-major access patterns, which can significantly improve performance for certain tensor operations like matrix multiplications. * Performs asynchronous transfers, allowing computation to overlap with memory operations. * Automatically determines optimal tile sizes based on tensor dimensions. * Uses hardware-accelerated swizzling to reduce shared memory bank conflicts. Notes: * This function requires NVIDIA GPUs with TMA support (compute capability 9.0+). * The source tensor must be in GENERIC or GLOBAL address space (DRAM). * The destination tensor must be in SHARED address space (SRAM). * Both tensors must have the same data type. 
* This function is asynchronous, so you must call [`async_copy_wait_all()`](/mojo/stdlib/gpu/memory/async_copy_wait_all/) or [`async_copy_wait_group()`](/mojo/stdlib/gpu/memory/async_copy_wait_group/) to ensure the copy has completed before using the data. * K-major layout is particularly beneficial for matrix multiplication operations where the inner dimension (K) is accessed contiguously. **Constraints:** * Requires NVIDIA GPUs with TMA support (compute capability 9.0+). * Source tensor must be in GENERIC or GLOBAL address space. * Destination tensor must be in SHARED address space. * Both tensors must have the same data type. * Source and destination tensors must be 2D. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the tensor elements. * ​eviction\_policy ([`CacheEviction`](/mojo/stdlib/gpu/memory/memory/CacheEviction)): The cache eviction policy to use. Default is `CacheEviction.EVICT_NORMAL`. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor, which must be in shared memory (SRAM). * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor, which must be in global or generic memory (DRAM).
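A schematic fragment of the wait pattern described above (the tile shape, tensor construction, and surrounding kernel are illustrative assumptions; `a_gmem` stands for a matching `GLOBAL`-space tile already in scope):

```mojo
from gpu.memory import async_copy_wait_all
from layout import Layout, LayoutTensor
from layout.layout_tensor import cp_async_k_major

# Inside a GPU kernel: stage a 64x64 bfloat16 tile in shared memory.
var a_smem = LayoutTensor[
    DType.bfloat16,
    Layout.row_major(64, 64),
    MutAnyOrigin,
    address_space = AddressSpace.SHARED,
].stack_allocation()
cp_async_k_major(a_smem, a_gmem)
# The transfer is asynchronous: wait before any thread reads a_smem.
async_copy_wait_all()
```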
--- ## layout_tensor
Provides the `LayoutTensor` type for representing multidimensional data. ## `comptime` values ### `binary_op_type` `comptime binary_op_type = fn[dtype: DType, width: Int](lhs: SIMD[dtype, width], rhs: SIMD[dtype, width]) -> SIMD[dtype, width]` Type alias for binary operations on SIMD vectors. This type represents a function that takes two SIMD vectors of the same type and width and returns a SIMD vector of the same type and width. Args: dtype: The data type of the SIMD vector elements. width: The width of the SIMD vector. lhs: Left-hand side SIMD vector operand. rhs: Right-hand side SIMD vector operand. Returns: A SIMD vector containing the result of the binary operation. ## Structs * [​`LayoutTensor`](./LayoutTensor): A high-performance tensor with explicit memory layout and hardware-optimized access patterns. * [​`LayoutTensorIter`](./LayoutTensorIter): Iterator for traversing a memory buffer with a specific layout. * [​`ThreadScope`](./ThreadScope): Represents the scope of thread operations in GPU programming. ## Functions * [​`copy_dram_to_local`](./copy_dram_to_local): Efficiently copy data from global memory (DRAM) to registers for AMD GPUs. * [​`copy_dram_to_sram`](./copy_dram_to_sram): Synchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. * [​`copy_dram_to_sram_async`](./copy_dram_to_sram_async): Asynchronously copy data from DRAM (global memory) to SRAM (shared memory) in a GPU context. * [​`copy_local_to_dram`](./copy_local_to_dram): Efficiently copy data from registers (LOCAL) to global memory (DRAM). * [​`copy_local_to_local`](./copy_local_to_local): Synchronously copy data between local memory (register) tensors with type conversion. * [​`copy_local_to_shared`](./copy_local_to_shared): Synchronously copy data from local memory (registers) to SRAM (shared memory). * [​`copy_sram_to_dram`](./copy_sram_to_dram): Synchronously copy data from SRAM (shared memory) to DRAM (global memory). 
* [​`copy_sram_to_local`](./copy_sram_to_local): Synchronously copy data from SRAM (shared memory) to local memory. * [​`cp_async_k_major`](./cp_async_k_major): Asynchronously copy data from DRAM to SRAM using TMA (Tensor Memory Accelerator) with K-major layout. * [​`stack_allocation_like`](./stack_allocation_like): Create a stack-allocated tensor with the same layout as an existing tensor.
--- ## stack_allocation_like
`stack_allocation_like[layout: Layout, dtype: DType, *, address_space: AddressSpace, target_address_space: AddressSpace = AddressSpace.GENERIC](in_tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=target_address_space, masked=masked]` Create a stack-allocated tensor with the same layout as an existing tensor. This function creates a new tensor on the stack with the same layout, data type, and masking properties as the input tensor, but potentially with a different address space. This is useful for creating temporary tensors that match the structure of existing tensors. Example: ```mojo from layout import LayoutTensor, Layout from layout.layout_tensor import stack_allocation_like var global_tensor = LayoutTensor[ DType.float32, Layout([10, 10]), MutAnyOrigin, address_space=AddressSpace.GLOBAL ].stack_allocation() var shared_tensor = stack_allocation_like[ target_address_space=AddressSpace.SHARED ](global_tensor) ``` Performance: * Creates a tensor on the stack, which is typically faster to allocate and access than heap-allocated memory. * Stack allocations have automatic lifetime management, reducing memory management overhead. * Stack size is limited, so be cautious with large tensor allocations. Notes: * The new tensor will have the same layout, data type, and masking properties as the input tensor. * The address space can be changed, which is useful for moving data between different memory regions (e.g., from global to shared memory). * Stack allocations are automatically freed when they go out of scope. * The function uses the stack\_allocation method of the result tensor type. **Parameters:** * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the tensor to allocate. 
* ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the tensor elements. * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The address space of the input tensor. * ​target\_address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The address space for the new tensor. Defaults to GENERIC. **Args:** * ​in\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to match the layout of. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor allocated on the stack with the same layout as the input tensor.
--- ## math
Implements math methods that work on layout tensors. ## Functions * [​`max`](./max): Computes maximum reduction along specified axis. * [​`mean`](./mean): Computes the mean value of the elements in a buffer. * [​`outer_product_acc`](./outer_product_acc): Updates result tensor with the outer product of two vectors. * [​`sum`](./sum): Computes sum reduction along specified axis. * [​`variance`](./variance): Computes the variance value of the elements in a buffer.
--- ## max
`max[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], outp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Computes maximum reduction along specified axis. Reduces the input tensor by taking maximum elements along the specified axis and stores the result in the output tensor. **Constraints:** All tensors must have statically known shapes. `outp.rank` must equal `inp.rank - 1`. Non-reduction dimensions must match between `inp` and `outp`. Currently only supports rank-2 inputs. **Parameters:** * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to take maximum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to reduce. * ​outp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor to store maximum results. `max[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, _reduce_res_row_major_shape(axis, layout), MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Computes maximum reduction along specified axis, returning a new tensor. Reduces the input tensor by taking maximum elements along the specified axis and returns a new tensor with the results. **Constraints:** All tensors must have statically known shapes. Result will have rank equal to `inp.rank` - 1. Non-reduction dimensions in the result match the input. Currently only supports rank-2 inputs. 
**Parameters:** * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to take maximum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to reduce. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the maximum values along the specified axis. `max[dtype: DType, layout: Layout](x: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], y: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].MutableAnyType` Computes element-wise maximum of two tensors. Returns a new tensor containing the element-wise maximum between the input tensors. **Constraints:** Input tensors must have statically known shapes and matching layouts. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the input tensors. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the input tensors. **Args:** * ​x ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): First input tensor. * ​y ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Second input tensor. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the element-wise maximum.
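By analogy with the `sum` example in this document, a minimal sketch of the returning overload (values are illustrative, and the printed formatting is assumed to match `sum`'s):

```mojo
from layout import LayoutTensor, Layout
from layout.math import max

data = InlineArray[Int32, 6](0, 1, 2, 3, 4, 5)
tensor = LayoutTensor[DType.int32, Layout.row_major(2, 3)](data)
# Rows are [0, 1, 2] and [3, 4, 5]; reducing along axis 0 takes the
# maximum down each column, giving [3, 4, 5].
print(max[0](tensor))
```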
--- ## mean
`mean(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> Scalar[dtype]` Computes the mean value of the elements in a buffer. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The buffer of elements for which the mean is computed. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The mean value of the elements in the given buffer. **Raises:** May raise on GPU targets when a device error occurs. `mean[reduce_axis: Int](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Computes the mean across `reduce_axis` of a tensor. **Parameters:** * ​reduce\_axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to reduce across. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer. * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer. **Raises:** May raise on GPU targets when a device error occurs.
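A minimal sketch of the scalar overload (values and the script-style top level are illustrative; note that `mean` may raise on GPU targets):

```mojo
from layout import LayoutTensor, Layout
from layout.math import mean

data = InlineArray[Float32, 6](1, 2, 3, 4, 5, 6)
tensor = LayoutTensor[DType.float32, Layout.row_major(2, 3)](data)
# (1 + 2 + 3 + 4 + 5 + 6) / 6 = 3.5
print(mean(tensor))
```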
--- ## outer_product_acc
`outer_product_acc(res: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lhs: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], rhs: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Updates result tensor with the outer product of two vectors. Computes `res += outer(lhs, rhs)` where `lhs` and `rhs` are vectors and `res` is a matrix. **Constraints:** All tensors must have statically known shapes. `res` must be rank 2. `lhs` and `rhs` must be rank 1. `res.shape[0]` `==` `lhs.shape[0]` and `res.shape[1]` `==` `rhs.shape[0]`. **Args:** * ​res ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The result matrix to accumulate into, shape (M, N). * ​lhs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The left-hand side vector, shape (M,). * ​rhs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The right-hand side vector, shape (N,).
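A small worked sketch of the accumulation (shapes and values are illustrative):

```mojo
from layout import LayoutTensor, Layout
from layout.math import outer_product_acc

res_data = InlineArray[Float32, 4](0, 0, 0, 0)
lhs_data = InlineArray[Float32, 2](1, 2)
rhs_data = InlineArray[Float32, 2](3, 4)
res = LayoutTensor[DType.float32, Layout.row_major(2, 2)](res_data)
lhs = LayoutTensor[DType.float32, Layout.row_major(2)](lhs_data)
rhs = LayoutTensor[DType.float32, Layout.row_major(2)](rhs_data)

# res += outer(lhs, rhs), i.e.
#   [[1*3, 1*4],    [[3, 4],
#    [2*3, 2*4]] =   [6, 8]]
outer_product_acc(res, lhs, rhs)
```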
--- ## sum (Math)
`sum[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], outp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Computes sum reduction along specified axis. Reduces the input tensor by summing elements along the specified axis and stores the result in the output tensor. Example: ```mojo from layout import LayoutTensor, Layout from layout.math import sum data = InlineArray[Int32, 6](0, 1, 2, 3, 4, 5) tensor = LayoutTensor[DType.int32, Layout.row_major(2, 3)](data) print(tensor) print("-----") print(sum[0](tensor)) ``` Output: ```plaintext 0 1 2 3 4 5 ----- 3 5 7 ``` **Constraints:** All tensors must have statically known shapes. `outp.rank` must equal `inp.rank - 1`. Non-reduction dimensions must match between inp and outp. Currently only supports rank-2 inputs. **Parameters:** * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to sum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to sum. * ​outp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor to store sum results. `sum[axis: Int](inp: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, _reduce_res_row_major_shape(axis, layout), MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type]` Computes sum reduction along specified axis, returning a new tensor. 
Reduces the input tensor by summing elements along the specified axis and returns a new tensor with the results. **Constraints:** All tensors must have statically known shapes. Result will have rank equal to `inp.rank` - 1. Non-reduction dimensions in the result match the input. Currently only supports rank-2 inputs. **Parameters:** * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to sum along. **Args:** * ​inp ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor to sum. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): A new tensor containing the sum values along the specified axis.
--- ## variance
`variance(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], correction: Int = 1) -> Scalar[dtype]` Computes the variance value of the elements in a buffer. ``` variance(x) = sum((x - E(x))^2) / (size - correction) ``` **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The buffer. * ​correction ([`Int`](/mojo/stdlib/builtin/int/Int)): Normalizes the variance by `size - correction`. Defaults to 1. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The variance value of the elements in a buffer. **Raises:** May raise on GPU targets when a device error occurs.
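A short sketch of the formula with the default `correction = 1` (sample variance); values are illustrative:

```mojo
from layout import LayoutTensor, Layout
from layout.math import variance

data = InlineArray[Float32, 4](1, 2, 3, 4)
tensor = LayoutTensor[DType.float32, Layout.row_major(4)](data)
# E(x) = 2.5; sum((x - 2.5)^2) = 2.25 + 0.25 + 0.25 + 2.25 = 5.0
# variance = 5.0 / (4 - 1) ≈ 1.667
print(variance(tensor))
```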
--- ## RuntimeLayout
`@register_passable(trivial)` `struct RuntimeLayout[layout: Layout, /, *, element_type: DType = DType.int64, linear_idx_type: DType = DType.int64]` A runtime-configurable layout that uses `RuntimeTuple` for storage. This struct provides a layout implementation that can be modified at runtime, unlike the static [`Layout`](/mojo/kernels/layout/layout/Layout) type. It uses [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple) for shape and stride storage. The layout's rank must be known at compile time, but the actual shape and stride values can be modified during execution. ## Parameters * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static `Layout` type to base this runtime layout on. * ​element\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type of each dimension element. Must be signed. * ​linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type of the linear index into memory returned by `crd2idx`. Must be signed. ## Fields * ​shape (`RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type].ShapeType`): The shape of the layout as a runtime tuple. Stores the size of each dimension using the specified `element_type`. Must match the static layout's shape dimensions. * ​stride (`RuntimeLayout[layout, element_type=element_type, linear_idx_type=linear_idx_type].StrideType`): The stride of the layout as a runtime tuple. Stores the stride (step size) for each dimension using `linear_idx_type`, since strides can be large values. Must match the static layout's stride dimensions. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ShapeType` `comptime ShapeType = RuntimeTuple[layout.shape, element_type=element_type]` Type alias for the runtime shape tuple. ### `StrideType` `comptime StrideType = RuntimeTuple[layout.stride, element_type=linear_idx_type]` Type alias for the runtime stride tuple. ## Methods ### `__init__` `__init__() -> Self` Initialize a `RuntimeLayout` with default values. Creates a new `RuntimeLayout` instance with default shape and stride values. Requires that the static layout has known dimensions at compile time. **Constraints:** The static layout that this runtime layout is based on must have all dimensions known. `__init__(shape: RuntimeTuple[layout.shape, element_type=element_type], stride: RuntimeTuple[layout.stride, element_type=linear_idx_type]) -> Self` Initialize a `RuntimeLayout` with specified shape and stride. **Args:** * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): A `RuntimeTuple` containing the dimensions of each axis. * ​stride ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): A `RuntimeTuple` containing the stride values for each axis. ### `__call__` `__call__(self, idx: Int) -> Scalar[linear_idx_type]` Convert a single index to a flat linear index. 
**Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The one-dimensional index to convert. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The corresponding flat linear index in the layout. `__call__[t: IntTuple](self, idx: RuntimeTuple[t, element_type=element_type]) -> Scalar[linear_idx_type]` Convert a multi-dimensional index to a flat linear index. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` type for the index. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): A `RuntimeTuple` containing the multi-dimensional coordinates. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The corresponding flat linear index in the layout. ### `idx2crd` `idx2crd[t: IntTuple](self, idx: RuntimeTuple[t, element_type=element_type]) -> RuntimeTuple[idx2crd(t, layout.shape, layout.stride), element_type=element_type]` Converts a linear index to logical coordinates. This is the inverse operation of the `__call__` method, mapping from a memory index back to the corresponding logical coordinates. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` type for the index. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The linear index to convert. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): The logical coordinates corresponding to the given index. ### `size` `size(self) -> Int` Calculate the total number of elements in the layout. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The product of all dimensions in the shape, representing the total number of elements that can be addressed by this layout. ### `bound_check_required` `bound_check_required(self) -> Bool` Determine if bounds checking is required for this layout. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if any dimension in the shape differs from the static layout's shape, False otherwise. 
### `cast` `cast[_element_type: DType, /, *, target_linear_idx_type: DType = linear_idx_type](self) -> RuntimeLayout[layout, element_type=_element_type, linear_idx_type=target_linear_idx_type]` Cast the layout to use a different element bitwidth. **Parameters:** * ​\_element\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The target data type. * ​target\_linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The target linear idx type. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A new `RuntimeLayout` with the shape cast to the specified type. ### `__str__` `__str__(self) -> String` Convert the layout to a string representation. **Returns:** `String`: A string representation of the layout. ### `row_major` `static row_major[rank: Int, //](shape: IndexList[rank, element_type=element_type]) -> Self` Create a row-major layout from the given shape. In row-major layout, elements with adjacent rightmost indices are adjacent in memory. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the layout. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): An `IndexList` containing the dimensions of each axis. **Returns:** `Self`: A `RuntimeLayout` with row-major stride ordering. ### `col_major` `static col_major[rank: Int, //](shape: IndexList[rank, element_type=element_type]) -> Self` Create a column-major layout from the given shape. In column-major layout, elements with adjacent leftmost indices are adjacent in memory. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions in the layout. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): An `IndexList` containing the dimensions of each axis. **Returns:** `Self`: A `RuntimeLayout` with column-major stride ordering. ### `write_to` `write_to(self, mut writer: T)` Write a string representation of the layout to a writer. 
**Args:** * ​writer (`T`): The `Writer` object to write the layout representation to. ### `sublayout` `sublayout[i: Int](self) -> RuntimeLayout[layout[i], element_type=element_type, linear_idx_type=linear_idx_type]` Extract a nested sublayout at the specified index. **Parameters:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the nested layout to extract. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A `RuntimeLayout` representing the nested layout at index i. ### `dim` `dim(self, i: Int) -> Int` Get the size of the dimension at the specified index. **Args:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the dimension to retrieve. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the dimension at index `i`. ### `__len__` `static __len__() -> Int` Get the number of dimensions in the layout. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of dimensions (rank) of the layout.
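The `row_major` and `col_major` factories above differ only in how strides are derived from the shape. A hedged Python sketch of that derivation (illustrative only; the Mojo factories operate on an `IndexList`):

```python
# Python sketch of row-major vs. column-major stride derivation;
# not the Mojo implementation.

def row_major_strides(shape):
    """Rightmost index varies fastest: strides are suffix products."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return strides

def col_major_strides(shape):
    """Leftmost index varies fastest: strides are prefix products."""
    strides = [1] * len(shape)
    for i in range(1, len(shape)):
        strides[i] = strides[i - 1] * shape[i - 1]
    return strides

print(row_major_strides([2, 3, 4]))  # [12, 4, 1]
print(col_major_strides([2, 3, 4]))  # [1, 2, 6]
```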
--- ## coalesce (Runtime_layout)
`coalesce[l: Layout, keep_rank: Bool = False](layout: RuntimeLayout[l, element_type=element_type, linear_idx_type=linear_idx_type]) -> RuntimeLayout[coalesce(l, keep_rank), element_type=element_type, linear_idx_type=linear_idx_type]` Coalesce adjacent dimensions in a runtime layout when possible. This optimizes the layout by merging adjacent dimensions when their relationship allows it, potentially reducing the number of dimensions. **Parameters:** * ​l ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static layout type to coalesce. * ​keep\_rank ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to maintain the original rank (currently unsupported). **Args:** * ​layout ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The input `RuntimeLayout` to coalesce. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A new `RuntimeLayout` with coalesced dimensions.
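The merge condition can be pictured with a minimal Python sketch: two adjacent modes combine when one mode's extent times its stride lands exactly on the next mode's stride. This is illustrative logic only, not the Mojo implementation:

```python
# Python sketch of dimension coalescing for a flat (shape, stride)
# pair; illustrative only.

def coalesce(shape, stride):
    out_shape, out_stride = [shape[0]], [stride[0]]
    for d, s in zip(shape[1:], stride[1:]):
        if d == 1:                       # size-1 modes contribute nothing
            continue
        if s == out_shape[-1] * out_stride[-1]:
            out_shape[-1] *= d           # contiguous: merge into previous mode
        else:
            out_shape.append(d)
            out_stride.append(s)
    return out_shape, out_stride

# A contiguous (4, 8) layout collapses to a single mode of 32:
print(coalesce([4, 8], [1, 4]))   # ([32], [1])
# Padded rows (stride 5) cannot be merged:
print(coalesce([4, 8], [1, 5]))   # ([4, 8], [1, 5])
```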
--- ## runtime_layout
Provides the `RuntimeLayout` type and functions for working with it. You can use `RuntimeLayout` to define a layout where the dimensions are not known at compile time. You can import these APIs from `layout.runtime_layout`. ```mojo from layout.runtime_layout import RuntimeLayout, make_layout ``` ## Structs * [​`RuntimeLayout`](./RuntimeLayout): A runtime-configurable layout that uses `RuntimeTuple` for storage. ## Functions * [​`coalesce`](./coalesce): Coalesce adjacent dimensions in a runtime layout when possible. * [​`make_layout`](./make_layout): Combine two runtime layouts into a single composite layout.
--- ## make_layout (Runtime_layout)
`make_layout[l1: Layout, l2: Layout, /, *, linear_idx_type: DType = DType.uint64](a: RuntimeLayout[l1, element_type=element_type, linear_idx_type=linear_idx_type], b: RuntimeLayout[l2, element_type=element_type, linear_idx_type=linear_idx_type]) -> RuntimeLayout[make_layout(l1, l2), element_type=element_type, linear_idx_type=linear_idx_type]` Combine two runtime layouts into a single composite layout. This creates a new layout by concatenating the dimensions and strides of the input layouts. **Parameters:** * ​l1 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static layout type of `a`. * ​l2 ([`Layout`](/mojo/kernels/layout/layout/Layout)): The static layout type of `b`. * ​linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type of all indices calculated by the returned runtime layout. **Args:** * ​a ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The first `RuntimeLayout` to combine. * ​b ([`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout)): The second `RuntimeLayout` to combine. **Returns:** [`RuntimeLayout`](/mojo/kernels/layout/runtime_layout/RuntimeLayout): A new `RuntimeLayout` with dimensions from both input layouts.
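The concatenation itself is simple. A hedged Python sketch with layouts modeled as `(shape, stride)` pairs (illustrative only; the Mojo version also combines the static `Layout` parameters):

```python
# make_layout concatenates the modes of two layouts; Python sketch,
# not the Mojo implementation.

def make_layout(a, b):
    (shape_a, stride_a), (shape_b, stride_b) = a, b
    return (shape_a + shape_b, stride_a + stride_b)

a = ([4], [1])   # 4 elements, unit stride
b = ([3], [4])   # 3 elements, stride 4
print(make_layout(a, b))   # ([4, 3], [1, 4])
```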
--- ## RuntimeTuple
`@register_passable(trivial)` `struct RuntimeTuple[S: IntTuple = -1, /, *, element_type: DType = DType.int64]` A struct representing tuple-like data with compile-time and runtime elements. RuntimeTuple combines static (compile-time) and dynamic (runtime) handling of tuple-like data structures, typically used for tensor shapes, indices, and coordinates in high-performance computing contexts. This struct is optimized for parallel execution and hardware acceleration, allowing efficient manipulation of multi-dimensional data. It supports both known compile-time values and runtime-determined values. ## Parameters * ​S ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): `IntTuple` with compile-time known values (or `UNKNOWN_VALUE` for runtime values). * ​element\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Integer type of the underlying elements. ## Fields * ​value (`IndexList[RuntimeTuple[S, element_type=element_type].scalar_length, element_type=element_type]`): Storage for the actual tuple values, implemented as an IndexList with the appropriate size and element type. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `scalar_length` `comptime scalar_length = len[IntTuple](flatten(S))` The total number of scalar elements in this RuntimeTuple after flattening nested tuples. ## Methods ### `__init__` `__init__() -> Self` Initialize a `RuntimeTuple` with default values. For dimensions with known compile-time values in S, uses those values. For unknown dimensions, initializes them to UNKNOWN\_VALUE. `__init__(*values: Int) -> Self` Initialize a `RuntimeTuple` with the provided values. **Args:** * ​\*values ([`Int`](/mojo/stdlib/builtin/int/Int)): Variadic number of integer values to initialize the tuple with. `@implicit` `__init__[l: Int](values: IndexList[l, element_type=element_type]) -> Self` Initialize a `RuntimeTuple` from an `IndexList`. **Parameters:** * ​l ([`Int`](/mojo/stdlib/builtin/int/Int)): Compile-time length of the input `IndexList`. **Args:** * ​values ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): `IndexList` to initialize from. Must have same length as the `RuntimeTuple`. The values will be cast to the appropriate element type if needed. 
### `__getitem__` `__getitem__[i: Int](self) -> RuntimeTuple[S[i], element_type=element_type]` Retrieves the element at the specified index in the tuple. This method provides array-like indexing for RuntimeTuple, allowing access to individual elements or sub-tuples. It handles the internal offset calculation to access the correct elements in the flattened storage array. **Parameters:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to retrieve. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the element or sub-tuple at the specified index. ### `__setitem__` `__setitem__[i: Int](mut self, val: Scalar[element_type])` Sets the value of the element at the specified index in the tuple. This method enables array-like assignment for RuntimeTuple elements, handling the internal offset calculation to modify the correct element in the flattened storage array. **Parameters:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to modify. **Args:** * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The new value to assign to the element. ### `offset_until` `static offset_until[i: Int]() -> Int` Calculates the offset in the flattened value array for a given tuple index. This method computes the sum of lengths of all flattened subtuple elements that come before the specified index, which is used for indexing into the internal storage. **Parameters:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The tuple index to calculate the offset for. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The offset in the flattened array where the i-th element begins. ### `get_int` `get_int(self) -> Scalar[element_type]` Returns the integer value of this RuntimeTuple. For tuples with a known compile-time value, returns that value. For tuples with a runtime value, returns the first element of the internal storage array. 
**Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The integer value of this RuntimeTuple. ### `__str__` `__str__(self) -> String` Converts the RuntimeTuple to its string representation. This method provides a human-readable string representation of the tuple, which is useful for debugging and logging. **Returns:** `String`: A string representation of the `RuntimeTuple`. ### `concat` `concat[R: IntTuple](self, rhs: RuntimeTuple[R, element_type=element_type]) -> RuntimeTuple[concat(S, R), element_type=element_type]` Concatenates two `RuntimeTuple`s together. This method combines the current `RuntimeTuple` with another one, preserving both compile-time and runtime values. It handles the complexity of merging the underlying storage arrays while maintaining the proper semantic structure. **Parameters:** * ​R ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The `IntTuple` type parameter of the right-hand side RuntimeTuple. **Args:** * ​rhs ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The `RuntimeTuple` to concatenate to the end of this one. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing all elements from both tuples in sequence. ### `flatten` `flatten(self) -> RuntimeTuple[flatten(S), element_type=element_type]` Flattens a potentially nested `RuntimeTuple` into a single-level tuple. This method converts a hierarchical structure of tuples into a flat representation, preserving all values but removing the nested structure. This is useful for operations that need to treat all elements uniformly. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing all elements in a flat (non-nested) structure. ### `write_to` `write_to(self, mut writer: T)` Writes the RuntimeTuple to a Writer object. This method is used by the string conversion system to generate a string representation of the RuntimeTuple. 
It handles both scalar values and nested tuple structures, producing a properly formatted output. **Args:** * ​writer (`T`): The Writer object to write the string representation to. ### `__len__` `__len__(self) -> Int` Returns the length (number of top-level elements) of the `RuntimeTuple`. This method provides the standard Python-like len() functionality, giving the number of elements at the top level of the tuple structure. For nested tuples, this returns the number of first-level entries, not the total number of scalar values. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of top-level elements in the tuple. ### `cast` `cast[dtype: DType](self) -> RuntimeTuple[S, element_type=dtype]` Casts the RuntimeTuple to use a different numeric type. This method creates a new RuntimeTuple with the same structure and values but using a different underlying numeric type for storage. This is useful for changing precision or signedness of the data. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The target DType to cast the elements to. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` with elements cast to the specified type. ### `__int__` `__int__(self) -> Int` Converts the RuntimeTuple to an integer value. This method enables implicit conversion of a RuntimeTuple to an integer, but is constrained to only work on scalar tuples (those that contain a single value). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value of the tuple.
--- ## coalesce_nested_tuple
`coalesce_nested_tuple[t: IntTuple, out_t: IntTuple = _int_tuple_product_flatten[t]()](tuple: RuntimeTuple[t, element_type=element_type]) -> RuntimeTuple[out_t]` Coalesces a nested `RuntimeTuple` into a single-level `RuntimeTuple`, by multiplying all the values together. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The underlying compile-time `IntTuple` backing the RuntimeTuple. * ​out\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The flattened compile-time `IntTuple`. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The RuntimeTuple to convert. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the products of each top-level tuple, in a flat structure.
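In other words, each top-level entry is replaced by the product of its (possibly nested) values. A hedged Python sketch of that behavior on plain tuples (illustrative only, not the Mojo implementation):

```python
# Each top-level entry collapses to the product of its nested values;
# Python sketch, illustrative only.
import math

def flatten(t):
    for x in t:
        if isinstance(x, tuple):
            yield from flatten(x)
        else:
            yield x

def coalesce_nested(t):
    return tuple(
        math.prod(flatten(x)) if isinstance(x, tuple) else x for x in t
    )

print(coalesce_nested(((2, 3), 4, (5, (2, 1)))))  # (6, 4, 10)
```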
--- ## concat
`concat(var lhs: IntTuple, rhs: IntTuple) -> IntTuple` Concatenates two `IntTuple` instances into a single `IntTuple`. This function appends all elements from the right-hand side tuple to the left-hand side tuple, creating a new combined tuple. The operation preserves the hierarchical structure of both tuples. **Args:** * ​lhs ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The left-hand side `IntTuple` that will be modified (var parameter). * ​rhs ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The right-hand side `IntTuple` whose elements will be appended. **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple): A new `IntTuple` containing all elements from both tuples in sequence.
--- ## crd2idx (Runtime_tuple)
`crd2idx[crd_t: IntTuple, shape_t: IntTuple, stride_t: IntTuple, out_type: DType = DType.uint64](crd: RuntimeTuple[crd_t, element_type=element_type], shape: RuntimeTuple[shape_t, element_type=element_type], stride: RuntimeTuple[stride_t, element_type=element_type]) -> Scalar[out_type]` Converts multi-dimensional coordinates to a linear index. This function is the inverse of idx2crd, transforming a set of coordinates into a flat index based on the provided shape and stride information. This is essential for mapping multi-dimensional tensor elements to linear memory. **Parameters:** * ​crd\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the coordinates. * ​shape\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the shape. * ​stride\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the stride. * ​out\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The output data type for the index (default: uint64). **Args:** * ​crd ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The coordinates to convert. * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The shape of the multi-dimensional array. * ​stride ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The stride values for each dimension. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): A scalar value representing the linear index corresponding to the given coordinates.
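For flat (non-nested) coordinates, `crd2idx` reduces to an inner product of coordinates and strides. A hedged Python sketch (illustrative only; the Mojo version also handles nested `RuntimeTuple` structures):

```python
# crd2idx as an inner product of coordinates and strides; Python
# sketch for the flat case, not the Mojo implementation.

def crd2idx(crd, shape, stride):
    for c, d in zip(crd, shape):
        assert 0 <= c < d, "coordinate out of bounds"
    return sum(c * s for c, s in zip(crd, stride))

# Element (2, 1) of a 3x4 array with strides (4, 1):
print(crd2idx((2, 1), (3, 4), (4, 1)))   # 9
```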
--- ## idx2crd (Runtime_tuple)
`idx2crd[idx_t: IntTuple, shape_t: IntTuple, stride_t: IntTuple](idx: RuntimeTuple[idx_t, element_type=element_type], shape: RuntimeTuple[shape_t, element_type=element_type], stride: RuntimeTuple[stride_t, element_type=element_type]) -> RuntimeTuple[idx2crd(idx_t, shape_t, stride_t), element_type=element_type]` Converts a linear index to multi-dimensional coordinates. This function transforms a flat index into coordinate values based on the provided shape and stride information. This is essential for mapping linear memory accesses to multi-dimensional tensor elements. **Constraints:** The index must be a scalar value (not a tuple). **Parameters:** * ​idx\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the index. * ​shape\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the shape. * ​stride\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the stride. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The linear index to convert. * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The shape of the multi-dimensional array. * ​stride ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The stride values for each dimension. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A `RuntimeTuple` containing the multi-dimensional coordinates. `idx2crd[idx_t: IntTuple, shape_t: IntTuple](idx: RuntimeTuple[idx_t, element_type=element_type], shape: RuntimeTuple[shape_t, element_type=element_type]) -> RuntimeTuple[idx2crd(idx_t, shape_t, prefix_product(shape_t)), element_type=element_type]` Converts a linear index to multi-dimensional coordinates using shape-derived strides. This is a convenience overload of `idx2crd` that automatically calculates the stride values from the shape using `prefix_product`. This is the common case for column-major storage order tensors, since prefix products of the shape yield column-major strides. 
**Parameters:** * ​idx\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the index. * ​shape\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): IntTuple type of the shape. **Args:** * ​idx ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The linear index to convert. * ​shape ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The shape of the multi-dimensional array. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A `RuntimeTuple` containing the multi-dimensional coordinates calculated using automatically derived strides from the shape.
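The stride-free overload can be pictured in a hedged Python sketch: derive strides from the shape with an exclusive running product, then divide the index down dimension by dimension (illustrative only, not the Mojo implementation):

```python
# idx2crd with shape-derived strides; Python sketch, illustrative only.

def prefix_product(shape):
    """Exclusive running product: [1, d0, d0*d1, ...]."""
    strides, acc = [], 1
    for d in shape:
        strides.append(acc)
        acc *= d
    return strides

def idx2crd(idx, shape):
    stride = prefix_product(shape)       # e.g. (3, 4) -> [1, 3]
    return tuple((idx // s) % d for d, s in zip(shape, stride))

print(idx2crd(7, (3, 4)))   # (1, 2), since 7 == 1*1 + 2*3
```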
--- ## runtime_tuple
Provides the `RuntimeTuple` data structure and related utility functions for handling tuple-like data with both compile-time and runtime elements. `RuntimeTuple` is designed for high-performance tensor operations, supporting efficient manipulation of multi-dimensional data structures like shapes, indices, and coordinates. Key features: * Hybrid compile-time/runtime value handling * Optimized for parallel execution and hardware acceleration * Support for nested tuple structures * Efficient conversion between linear indices and multi-dimensional coordinates * Specialized operations for tensor shape calculations The module includes functions for tuple manipulation (concatenation, flattening), coordinate transformations (`idx2crd`, `crd2idx`), and specialized tensor operations like shape division and prefix products. ## Structs * [​`RuntimeTuple`](./RuntimeTuple): A struct representing tuple-like data with compile-time and runtime elements. RuntimeTuple combines static (compile-time) and dynamic (runtime) handling of tuple-like data structures, typically used for tensor shapes, indices, and coordinates in high-performance computing contexts. This struct is optimized for parallel execution and hardware acceleration, allowing efficient manipulation of multi-dimensional data. It supports both known compile-time values and runtime-determined values. ## Functions * [​`coalesce_nested_tuple`](./coalesce_nested_tuple): Coalesces a nested `RuntimeTuple` into a single-level `RuntimeTuple`, by multiplying all the values together. * [​`concat`](./concat): Concatenates two `IntTuple` instances into a single `IntTuple`. * [​`crd2idx`](./crd2idx): Converts multi-dimensional coordinates to a linear index. * [​`idx2crd`](./idx2crd): Converts a linear index to multi-dimensional coordinates. This function transforms a flat index into coordinate values based on the provided shape and stride information. 
This is essential for mapping linear memory accesses to multi-dimensional tensor elements. * [​`is_int`](./is_int): Determines if a `RuntimeTuple` represents a scalar integer value. * [​`is_tuple`](./is_tuple): Determines if a `RuntimeTuple` represents a tuple rather than a scalar value. * [​`prefix_product`](./prefix_product): Computes the prefix products of elements in the `RuntimeTuple`. * [​`product`](./product): Computes the product of all elements in the `RuntimeTuple`. * [​`shape_div`](./shape_div): Performs specialized shape division between `RuntimeTuple`s. * [​`signum`](./signum): Returns the sign of an integer value. * [​`to_index_list`](./to_index_list): Converts a RuntimeTuple to an IndexList with the same values.
--- ## is_int (Runtime_tuple)
`is_int[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> Bool` Determines if a `RuntimeTuple` represents a scalar integer value. This function checks if the `RuntimeTuple` holds a single scalar value rather than a tuple structure with multiple elements. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The `RuntimeTuple` to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `RuntimeTuple` represents a scalar integer, False otherwise.
--- ## is_tuple (Runtime_tuple)
`is_tuple[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> Bool` Determines if a `RuntimeTuple` represents a tuple rather than a scalar value. This function checks the structure of the underlying IntTuple to determine if it represents a tuple with multiple elements or a single scalar value. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The `RuntimeTuple` to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `RuntimeTuple` represents a tuple, False if it represents a scalar.
--- ## prefix_product (Runtime_tuple)
`prefix_product[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> RuntimeTuple[prefix_product(t)]` Computes the prefix products of elements in the `RuntimeTuple`. This function calculates the running product of elements, where each output element is the product of all previous elements in the input. This is commonly used in tensor computations to calculate stride values. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the input RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The input `RuntimeTuple`. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the prefix products of the input elements.
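The operation is an exclusive scan: output element `i` is the product of all inputs before position `i`. A hedged Python one-liner sketch (illustrative only, not the Mojo implementation):

```python
# prefix_product as an exclusive running product; Python sketch.
from itertools import accumulate
import operator

def prefix_product(t):
    # Shift in a leading 1 and drop the last element, then scan with *.
    return tuple(accumulate((1,) + t[:-1], operator.mul))

print(prefix_product((2, 3, 4)))   # (1, 2, 6) -- usable as stride values
```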
--- ## product (Runtime_tuple)
`product[t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> Int` Computes the product of all elements in the `RuntimeTuple`. This function multiplies all scalar values in the tuple, including those in nested tuples after flattening. This is commonly used to calculate the total size of a tensor from its shape. **Parameters:** * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple type parameter of the input RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The input `RuntimeTuple`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The product of all scalar elements in the tuple.
--- ## shape_div (Runtime_tuple)
`shape_div[a_t: IntTuple, b_t: IntTuple](a: RuntimeTuple[a_t, element_type=element_type], b: RuntimeTuple[b_t, element_type=element_type]) -> RuntimeTuple[shape_div(a_t, b_t)]` Performs specialized shape division between `RuntimeTuple`s. This function implements a special division operation specifically designed for tensor shape calculations. Unlike standard division, it handles special cases: 1. If shapes are directly divisible (a % b == 0), returns a standard division (a // b) 2. If shapes are inversely divisible (b % a == 0), returns the signed reciprocal 3. If shapes are incompatible, aborts with an error This operation is essential for transformations between tensor layouts and computing broadcasting semantics. **Parameters:** * ​a\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the first operand. * ​b\_t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): Type of the second operand. **Args:** * ​a ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The dividend `RuntimeTuple`. * ​b ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The divisor `RuntimeTuple`. **Returns:** [`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple): A new `RuntimeTuple` containing the result of the shape division.
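The three scalar cases can be sketched in Python. The "signed reciprocal" convention shown here (a CuTe-style `signum(a * b)`) is an assumption about the exact return value, so treat it as illustrative rather than the Mojo implementation:

```python
# shape_div's three scalar cases; Python sketch under the assumed
# signed-reciprocal convention, not the Mojo implementation.

def signum(x):
    return (x > 0) - (x < 0)

def shape_div(a, b):
    if a % b == 0:
        return a // b          # case 1: directly divisible
    if b % a == 0:
        return signum(a * b)   # case 2: inversely divisible (signed reciprocal)
    raise ValueError(f"incompatible shapes: {a}, {b}")  # case 3

print(shape_div(8, 2))   # 4
print(shape_div(2, 8))   # 1
```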
--- ## signum (Runtime_tuple)
`signum(a: Int) -> Int` Returns the sign of an integer value. This helper function determines whether a number is positive, negative, or zero, returning 1 for positive, -1 for negative, and 0 for zero. **Args:** * ​a ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to determine the sign of. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): 1 if a > 0, -1 if a < 0, 0 if a == 0.
--- ## to_index_list (Runtime_tuple)
`to_index_list[rank: Int, t: IntTuple](tuple: RuntimeTuple[t, element_type=element_type]) -> IndexList[rank]` Converts a RuntimeTuple to an IndexList with the same values. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the resulting IndexList. * ​t ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple template parameter of the RuntimeTuple. **Args:** * ​tuple ([`RuntimeTuple`](/mojo/kernels/layout/runtime_tuple/RuntimeTuple)): The RuntimeTuple to convert. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): An IndexList filled with the values of the RuntimeTuple.
--- ## ComposedLayout
`struct ComposedLayout[LayoutA: LayoutTrait, LayoutB: LayoutTrait, offset: OptionalReg[Int] = 0]` Layout composed of two layouts applied sequentially. Combines two layouts. Output of the first (`LayoutA`) is input to the second (`LayoutB`), with optional offset in between. ## Parameters * ​LayoutA ([`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait)): The first layout to apply. * ​LayoutB ([`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait)): The second layout to apply. * ​offset ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional offset between layouts (default: 0). ## Fields * ​layout\_a (`LayoutA`): The first layout to apply. * ​layout\_b (`LayoutB`): The second layout to apply. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = LayoutB.__del__is_trivial if LayoutA.__del__is_trivial else LayoutA.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = LayoutB.__moveinit__is_trivial if LayoutA.__moveinit__is_trivial else LayoutA.__moveinit__is_trivial` ### `has_shape` `comptime has_shape = LayoutA.has_shape if LayoutA.has_shape else LayoutB.has_shape` True if either layout has a shape. ## Methods ### `__init__` `__init__(out self, layout_a: LayoutA, layout_b: LayoutB)` Initialize ComposedLayout with two layouts. **Args:** * ​layout\_a (`LayoutA`): The first layout. * ​layout\_b (`LayoutB`): The second layout. ### `__copyinit__` `__copyinit__(out self, other: Self)` Copy constructor for ComposedLayout. 
**Args:** * ​other (`Self`): The ComposedLayout to copy from. ### `__call__` `__call__(self, idx: IntTuple) -> Int` Apply composed layout to an index. Applies `LayoutA`, then adds offset, then applies `LayoutB`. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The index to transform. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The transformed index. `__call__(self, idx: IntTuple, offset_val: Int) -> Int` Apply composed layout with runtime offset. Applies `LayoutA`, then adds runtime `offset_val`, then `LayoutB`. Static offset must not be set when using runtime offset. **Args:** * ​idx ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The index to transform. * ​offset\_val ([`Int`](/mojo/stdlib/builtin/int/Int)): Runtime offset to apply. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The transformed index. ### `size` `size(self) -> Int` Get the size of the composed layout. Returns the size of the first layout (`LayoutA`). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the first layout. ### `cosize` `cosize(self) -> Int` Get the cosize of the composed layout. Returns the cosize of the second layout (`LayoutB`). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The cosize of the second layout.
--- ## Swizzle
`@register_passable(trivial)` `struct Swizzle` Swizzle functor for memory access pattern optimization. Implements a swizzling pattern to reduce bank conflicts in shared memory accesses. It XORs specific bits of memory indices based on configurable parameters. Swizzle operation: Given index `i`, and Swizzle\[bits, base, shift]: 1. Extract `bits` number of bits from `i` starting from position `base + max(0, shift)`. Let's call this `YYY`. 2. Extract `bits` number of bits from `i` starting from position `base - min(0, shift)`. Let's call this `ZZZ`. 3. Result is `i ^ (YYY shifted by 'shift' positions)`. Example (Swizzle\[2, 0, 3]): Input index bits: `xxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxxx` Output index bits: `xxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxxx` where `AA = ZZ ^ YY`. Attributes: bits (Int): Number of bits in the mask (YYY). base (Int): Number of least significant bits to keep constant. shift (Int): Shift distance for the mask (positive: right, negative: left). yyy\_mask (Int): Mask for the bits to be shifted (YYY). zzz\_mask (Int): Mask for the target bits (ZZZ). ## Fields * ​bits (`Int`): Number of bits in the mask. * ​base (`Int`): Number of least significant bits to keep constant. * ​shift (`Int`): Distance to shift the mask (pos right, neg left). * ​yyy\_mask (`Int`): Mask for the bits to be shifted. * ​zzz\_mask (`Int`): Mask for the target bits. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`LayoutTrait`](/mojo/kernels/layout/layout/LayoutTrait), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `has_shape` `comptime has_shape = False` Indicates if layout has shape. Swizzle always False. ## Methods ### `__init__` `__init__(bits: Int, base: Int, shift: Int) -> Self` Initialize a Swizzle object. Configures the swizzle operation based on bits, base, and shift parameters. **Args:** * ​bits ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of bits in the mask. * ​base ([`Int`](/mojo/stdlib/builtin/int/Int)): Least significant bits to keep constant. * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): Distance to shift the mask. ### `__call__` `__call__(self, index: IntTuple) -> Int` Apply swizzle to an IntTuple index. Unwraps the IntTuple and applies the swizzle to the integer value. **Args:** * ​index ([`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)): The IntTuple index to swizzle. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The swizzled index value. `__call__(self, offset: Int) -> Int` Apply swizzle to an integer offset. Performs the swizzle operation on an integer offset to rearrange memory access patterns. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer offset to swizzle. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The swizzled offset value. 
`__call__(self, offset: Scalar[dtype]) -> Scalar[dtype]` Apply swizzle to a scalar offset. Scalar version of the swizzle operation. Applies swizzle to a scalar offset. **Args:** * ​offset ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar offset to swizzle. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The swizzled scalar value. ### `size` `size(self) -> Int` Get the size of the swizzle pattern. Calculates the size of the memory region affected by the swizzle pattern. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the swizzle pattern. ### `cosize` `cosize(self) -> Int` Get the cosize of the swizzle pattern. Cosize is the same as size for swizzle layouts, representing the output size. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The cosize of the swizzle pattern (same as size). ### `write_to` `write_to(self, mut writer: T)` Write the swizzle parameters to a writer. Outputs the swizzle parameters (bits, base, shift) in a tuple format. **Args:** * ​writer (`T`): The writer to write to. ### `__str__` `__str__(self) -> String` Convert the swizzle to a string representation. **Returns:** `String`: String representation of the swizzle parameters.
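The swizzle steps above are pure bit arithmetic, so they can be modeled outside Mojo. The following is an illustrative Python sketch of the documented operation (a model for exposition, not the Mojo API):

```python
def swizzle(i: int, bits: int, base: int, shift: int) -> int:
    """Python model of Swizzle[bits, base, shift] applied to index i."""
    mask = (1 << bits) - 1
    # YYY: `bits` bits starting at position base + max(0, shift)
    yyy_mask = mask << (base + max(0, shift))
    # XOR the YYY bits into i, shifted by `shift` (positive: right, negative: left)
    if shift >= 0:
        return i ^ ((i & yyy_mask) >> shift)
    return i ^ ((i & yyy_mask) << -shift)

# Swizzle[2, 0, 3]: the two bits at positions 3-4 (YY) are XORed
# into the two bits at positions 0-1 (ZZ).
print(swizzle(0b11000, 2, 0, 3))  # 27, i.e. 0b11011
```

Note that bits below `base` and the YYY bits themselves are unchanged; only the ZZZ bits are rewritten.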
--- ## eval_composed
`eval_composed[composed_layout: ComposedLayout[Layout, Swizzle]](idx: UInt, offset: UInt = 0) -> UInt` Evaluate a composed layout with swizzle. Evaluates a `ComposedLayout[Layout, Swizzle]`. Applies the base layout, adds an optional offset, and then applies the swizzle. **Parameters:** * ​composed\_layout ([`ComposedLayout`](/mojo/kernels/layout/swizzle/ComposedLayout)): The composed layout to evaluate, consisting of a base Layout and a Swizzle transformation. **Args:** * ​idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The input index to transform. * ​offset ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Optional offset to apply between layouts (default: 0). **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The transformed index after applying both layouts.
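The evaluation order (base layout first, then the optional offset, then the swizzle) can be sketched in Python. The `row_major` helper and the swizzle closure below are illustrative assumptions, not the Mojo `Layout` or `Swizzle` types:

```python
def row_major(rows: int, cols: int):
    # base layout: (r, c) -> linear index
    return lambda r, c: r * cols + c

def make_swizzle_fn(bits: int, base: int, shift: int):
    mask = (1 << bits) - 1
    yyy = mask << (base + max(0, shift))
    return lambda i: i ^ ((i & yyy) >> shift)

def eval_composed(layout_fn, swizzle_fn, r, c, offset=0):
    # apply base layout, add the optional offset, then apply the swizzle
    return swizzle_fn(layout_fn(r, c) + offset)

layout = row_major(8, 8)
swz = make_swizzle_fn(3, 0, 3)
print(eval_composed(layout, swz, 2, 5))  # swizzle(2*8 + 5) = 23
```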
--- ## swizzle
Defines swizzle layouts for optimizing memory access patterns. This module is designed for use in shared memory, especially in GPU kernels, to reduce bank conflicts. It provides tools to create and apply swizzle transformations to memory indices. Swizzling rearranges memory access order to distribute accesses across different memory banks. This mitigates bank contention and improves memory access efficiency. Module components: * `Swizzle` struct: Represents a swizzle transformation with configurable bits, base, and shift parameters. * Helper functions: `make_ldmatrix_swizzle` and `make_swizzle` create predefined swizzle patterns. These are optimized for scenarios like `ldmatrix` instructions and general 2D memory access. * `ComposedLayout` struct: Combines a base layout with a swizzle layout for complex memory access optimizations. Primary use case: GPU kernel development where shared memory bank conflicts can degrade performance. Applying swizzle layouts optimizes memory access patterns for higher throughput. ## Structs * [​`ComposedLayout`](./ComposedLayout): Layout composed of two layouts applied sequentially. * [​`Swizzle`](./Swizzle): Swizzle functor for memory access pattern optimization. ## Functions * [​`eval_composed`](./eval_composed): Evaluate a composed layout with swizzle. * [​`make_ldmatrix_swizzle`](./make_ldmatrix_swizzle): Make a swizzle to avoid bank conflicts for `ldmatrix` ops. * [​`make_swizzle`](./make_swizzle): Create a 2D swizzle to avoid bank conflicts. * [​`shiftl`](./shiftl): Shift left or right based on sign of shift amount. * [​`shiftr`](./shiftr): Shift right or left based on sign of shift amount.
--- ## make_ldmatrix_swizzle
`make_ldmatrix_swizzle[dtype: DType, row_size: Int, log2_vector_width: Int = 0]() -> Swizzle` Make a swizzle to avoid bank conflicts for `ldmatrix` ops. Creates a swizzle pattern optimized for `ldmatrix` operations, minimizing bank conflicts in shared memory for these operations. Calculates swizzle parameters based on data type and row size. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements. * ​row\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of each row in elements. * ​log2\_vector\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): Log2 of the vector width (default: 0). **Returns:** `Swizzle`: A `Swizzle` object configured for `ldmatrix`.
--- ## make_swizzle
`make_swizzle[num_rows: Int, row_size: Int, access_size: Int]() -> Swizzle` Create a 2D swizzle to avoid bank conflicts. Generates a swizzle pattern for 2D memory layout to minimize bank conflicts in shared memory access. **Parameters:** * ​num\_rows ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the minimum access pattern. * ​row\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of each row in elements. * ​access\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements accessed at once. **Returns:** `Swizzle`: A `Swizzle` object for 2D memory access. `make_swizzle[dtype: DType, mode: TensorMapSwizzle]() -> Swizzle` Create swizzle based on predefined swizzle modes. Returns a swizzle pattern based on standard modes (32B, 64B, 128B, none), adjusted for data type. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements. * ​mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzle mode to use (TensorMapSwizzle enum). **Returns:** `Swizzle`: A `Swizzle` object configured by the specified mode.
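To see why such a pattern helps, consider a 32-column float32 tile in shared memory with 32 four-byte banks: a naive column access hits the same bank on every row, while a swizzled layout spreads the same accesses across all banks. The sketch below is illustrative Python with hand-picked swizzle parameters, not the values `make_swizzle` actually computes:

```python
NUM_BANKS = 32
ROW_SIZE = 32  # float32 elements per row; one element per bank

def bank(idx: int) -> int:
    # bank of a 4-byte element at linear index idx
    return idx % NUM_BANKS

def swizzle(i: int) -> int:
    # XOR the row bits into the column bits (bits=5, base=0, shift=5)
    return i ^ ((i >> 5) & 0b11111)

col = 3
naive = {bank(r * ROW_SIZE + col) for r in range(32)}
swizzled = {bank(swizzle(r * ROW_SIZE + col)) for r in range(32)}
print(len(naive), len(swizzled))  # 1 32
```

With the swizzle, row `r` of column `c` lands in bank `c ^ r`, so the 32 rows of one column touch 32 distinct banks.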
--- ## shiftl
`shiftl(a: Int, s: Int) -> Int` Shift left or right based on sign of shift amount. Performs a left shift if `s` is positive, or a right shift if `s` is negative. **Args:** * ​a ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to shift. * ​s ([`Int`](/mojo/stdlib/builtin/int/Int)): The shift amount. Positive for left, negative for right. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The shifted integer value. `shiftl(a: Scalar[dtype], s: Scalar[dtype]) -> Scalar[dtype]` Shift left/right based on sign of shift for scalars. Scalar version of `shiftl`. Left shift if `s` is positive, right shift if `s` is negative. **Args:** * ​a ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to shift. * ​s ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar shift amount. Positive for left, negative for right. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The shifted scalar value.
--- ## shiftr
`shiftr(a: Int, s: Int) -> Int` Shift right or left based on sign of shift amount. Performs a right shift if `s` is positive, or a left shift if `s` is negative. **Args:** * ​a ([`Int`](/mojo/stdlib/builtin/int/Int)): The integer value to shift. * ​s ([`Int`](/mojo/stdlib/builtin/int/Int)): The shift amount. Positive for right, negative for left. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The shifted integer value. `shiftr(a: Scalar[dtype], s: Scalar[dtype]) -> Scalar[dtype]` Shift right/left based on sign of shift for scalars. Scalar version of `shiftr`. Right shift if `s` is positive, left shift if `s` is negative. **Args:** * ​a ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value to shift. * ​s ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar shift amount. Positive for right, negative for left. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The shifted scalar value.
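Both helpers reduce to a sign-dispatched shift. An equivalent Python sketch (illustrative; the Mojo versions also have `Scalar` overloads):

```python
def shiftl(a: int, s: int) -> int:
    # left shift for positive s, right shift for negative s
    return a << s if s >= 0 else a >> -s

def shiftr(a: int, s: int) -> int:
    # right shift for positive s, left shift for negative s
    return a >> s if s >= 0 else a << -s

print(shiftl(1, 3), shiftl(8, -2))  # 8 2
print(shiftr(8, 3), shiftr(1, -3))  # 1 8
```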
--- ## TensorCore
`struct TensorCore[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool = False]` TensorCore provides an abstraction for GPU tensor core hardware to perform optimized matrix operations. This struct encapsulates the functionality required to efficiently map matrix operations to Tensor Cores on NVIDIA and AMD GPUs. It handles loading matrix fragments, performing matrix multiply-accumulate operations, and storing results with hardware-specific optimizations. Note: Different shapes and data types are supported depending on the GPU hardware. For NVIDIA GPUs: * float32: 16x8x8 or 16x8x4 * half-precision: 16x8x16 * float8: 16x8x32 For AMD GPUs: * float32: 16x16x4 * half-precision: 16x16x16 or 32x32x8 ## Parameters * ​out\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for output/accumulation operations. * ​in\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for input matrix elements. * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape parameters for the matrix operation in the form \[M, N, K] where MxN is the output shape and K is the inner dimension. * ​transpose\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose the B matrix before multiplication. Defaults to False. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_reg_type` `comptime a_reg_type = SIMD[in_type, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](0), shape.__getitem__[3, DType.int64, Int](2)]()]` SIMD type for the A operand registers. ### `b_reg_type` `comptime b_reg_type = SIMD[in_type, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](2), shape.__getitem__[3, DType.int64, Int](1)]()]` SIMD type for the B operand registers. ### `c_reg_tile_type` `comptime c_reg_tile_type = LayoutTensor[out_type, Layout.col_major(1, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](0), shape.__getitem__[3, DType.int64, Int](1)]()), MutAnyOrigin, address_space=AddressSpace.LOCAL]` LayoutTensor type for the C register tile. ### `c_reg_type` `comptime c_reg_type = SIMD[out_type, num_matrix_reg[shape.__getitem__[3, DType.int64, Int](0), shape.__getitem__[3, DType.int64, Int](1)]()]` SIMD type for the C/accumulator operand registers. ### `supported_fp32` `comptime supported_fp32 = (shape == IndexList[3, DType.int64](16, 8, 8, Tuple[]())) if is_nvidia_gpu() else (shape == IndexList[3, DType.int64](16, 16, 4, Tuple[]())) if (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> in_type, "_mlir_value">> == 81) else (in_type is DType.float32)` Whether float32 is supported for this tensor core configuration. 
### `supported_fp64` `comptime supported_fp64 = Tuple[IndexList[3], IndexList[3], IndexList[3], IndexList[3]](VariadicPack[True, True, origin_of(), Movable, IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape_8x8x4, shape_16x8x4, shape_16x8x8, shape_16x8x16)).__contains__[IndexList[3], IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape) if (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> out_type, "_mlir_value">> == 82) if (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> in_type, "_mlir_value">> == 82) else (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> in_type, "_mlir_value">> == 82) else (out_type is DType.float64) if (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> in_type, "_mlir_value">> == 82) else (in_type is DType.float64) if is_nvidia_gpu() else False` Whether float64 is supported for this tensor core configuration. ### `supported_fp8` `comptime supported_fp8 = (shape == shape_16x8x32) if Tuple[DType, DType](VariadicPack[True, True, origin_of(), Movable, DType, DType](DType.float8_e4m3fn, DType.float8_e5m2)).__contains__[DType, DType, DType](in_type) else Tuple[DType, DType](VariadicPack[True, True, origin_of(), Movable, DType, DType](DType.float8_e4m3fn, DType.float8_e5m2)).__contains__[DType, DType, DType](in_type) if is_nvidia_gpu() else (shape == shape_16x16x32) if Tuple[DType, DType](VariadicPack[True, True, origin_of(), Movable, DType, DType](get_amd_fp8_dtype(), get_amd_bf8_dtype())).__contains__[DType, DType, DType](in_type) else Tuple[DType, DType](VariadicPack[True, True, origin_of(), Movable, DType, DType](get_amd_fp8_dtype(), get_amd_bf8_dtype())).__contains__[DType, DType, DType](in_type)` Whether float8 is supported for this tensor core configuration. 
### `supported_half` `comptime supported_half = (shape == shape_16x8x16) if is_nvidia_gpu() else Tuple[IndexList[3], IndexList[3], IndexList[3], IndexList[3]](VariadicPack[True, True, origin_of(), Movable, IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape_16x16x16, shape_16x16x32, shape_32x32x8, shape_32x32x16)).__contains__[IndexList[3], IndexList[3], IndexList[3], IndexList[3], IndexList[3]](shape) if (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> in_type, "_mlir_value">> == 80) if (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> in_type, "_mlir_value">> == 80) else (#pop.dtype_to_ui8<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@dtype::@DType> in_type, "_mlir_value">> == 79) else in_type.is_half_float()` Whether half-precision float is supported for this configuration. ## Methods ### `__init__` `__init__(out self)` Initialize a new TensorCore instance. ### `get_shapes` `static get_shapes[_out_type: DType, _in_type: DType]() -> List[IndexList[3]]` Get supported shapes for given data types. Returns a list of valid shapes for the specified output and input data types. Note: The returned shapes are hardware-dependent. Different shapes are supported for different combinations of input and output types. **Parameters:** * ​\_out\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The output/accumulation data type. * ​\_in\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The input matrix data type. **Returns:** [`List`](/mojo/stdlib/collections/list/List): List\[IndexList\[3]]: Valid shapes for the matrix operations given the specified types. 
### `load_a` `load_a[swizzle: OptionalReg[Swizzle] = None](self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_a_reg_tile_layout[layout, shape](), MutAnyOrigin, address_space=AddressSpace.LOCAL]` Load the A matrix fragments. Loads matrix A from memory into a LayoutTensor suitable for tensor core operations. **Parameters:** * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzle pattern for optimal memory access (AMD only). **Args:** * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source matrix A data. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The loaded matrix fragments as a `LayoutTensor`. `load_a[swizzle: OptionalReg[Swizzle] = None](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = 0)` Load A matrix fragments from shared memory. Optimized version for loading A matrix fragments from shared memory. **Parameters:** * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional memory access pattern to optimize memory bandwidth. **Args:** * ​warp\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source data in shared memory. * ​fragments ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor for fragments. 
* ​mma\_tile\_coord\_k ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The K coordinate of the MMA tile. Defaults to 0. ### `load_b` `load_b[swizzle: OptionalReg[Swizzle] = None](self, b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[in_type, _get_b_reg_tile_layout[layout, shape, transpose_b](), MutAnyOrigin, address_space=AddressSpace.LOCAL]` Load the B matrix fragments. Loads matrix B from memory into a `LayoutTensor` suitable for tensor core operations. The function handles different hardware architectures and memory access patterns. Note: If transpose\_b is `True`, the B matrix will be transposed during loading. This is more efficient than transposing the matrix in memory. **Parameters:** * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzle pattern for optimal memory access (AMD only). Will cause an error if used with NVIDIA GPUs. **Args:** * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source matrix B data. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor): The loaded matrix fragments as a `LayoutTensor`. `load_b[swizzle: OptionalReg[Swizzle] = None](self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = 0, warp_tile_coord_n: UInt = 0)` Load B matrix fragments from shared memory into registers for tensor core operations. 
This function loads matrix B fragments from a warp tile in shared memory into register fragments for use in tensor core matrix multiply operations. It handles hardware-specific optimizations for both NVIDIA and AMD GPUs. Note: The `warp_tile` must be in shared memory. For NVIDIA GPUs, `swizzle` must be `None`. For AMD GPUs, providing an appropriate `swizzle` pattern can improve performance. **Parameters:** * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional memory access pattern for AMD GPUs to optimize memory bandwidth. Must be `None` on NVIDIA GPUs, where an internal swizzle pattern is always applied. **Args:** * ​warp\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source `LayoutTensor` in shared memory containing the B matrix data. * ​fragments ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination `LayoutTensor` to store the loaded matrix fragments. * ​mma\_tile\_coord\_k ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): K-dimension coordinate within the warp tile. Defaults to 0. * ​warp\_tile\_coord\_n ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): N-dimension coordinate within the warp tile. Defaults to 0. `load_b(self, warp_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], fragments: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scales: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_tile_coord_k: UInt = 0)` Load quantized B matrix fragments from shared memory with dequantization. 
This function loads int4 quantized matrix B fragments from shared memory, dequantizes them using the provided scales, and stores the result in register fragments for tensor core operations. Notes: * The `warp_tile` must be in shared memory. * The `fragments` and `scales` must be in local memory. * This function only supports half-precision data types (bfloat16, float16). * The quantized data is stored as int4 values packed into int32 elements. * Each thread processes multiple fragments by unpacking and dequantizing the int4 values. **Args:** * ​warp\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source `LayoutTensor` in shared memory containing the quantized B matrix data. * ​fragments ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination `LayoutTensor` to store the dequantized matrix fragments. * ​scales ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): `LayoutTensor` containing the scaling factors for dequantization. * ​mma\_tile\_coord\_k ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): K-dimension coordinate within the warp tile. Defaults to 0. ### `load_c` `load_c(self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TensorCore[out_type, in_type, shape, transpose_b].c_reg_tile_type` Load the C matrix fragments. Loads matrix C from memory into a `LayoutTensor` suitable for tensor core operations. The function handles different hardware architectures and memory access patterns. **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source matrix C data. **Returns:** `TensorCore`: The loaded matrix fragments as a `LayoutTensor`. 
### `store_d` `store_d(self, d_dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], d_src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Store matrix D to destination memory. Stores the result matrix D from tensor core computation to the destination memory. **Args:** * ​d\_dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor to store the result. * ​d\_src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor containing the computed result. ### `mma_op` `mma_op(self, a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TensorCore[out_type, in_type, shape, transpose_b].c_reg_tile_type` Perform matrix multiply-accumulate operation (MMA). Executes `D = A * B + C` using tensor cores. **Args:** * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The A matrix input. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The B matrix input. * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The C matrix input for accumulation. 
**Returns:** `TensorCore`: `Self.c_reg_tile_type`: The result of the MMA operation. ### `mma` `mma(self, a_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_frag: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Perform matrix multiply-accumulate operation using tensor cores. Executes C = A \* B + C using tensor cores, where A, B, and C are matrix fragments stored in register memory. This function handles the mapping of fragments to hardware tensor core operations. Notes: * All fragments must be properly loaded using the corresponding load functions. * The function assumes fragments are vectorized layout tensors with dimensions num\_vectors x 1. * The c\_frag shape\[0] must equal num\_m\_mmas \* num\_n\_mmas. * The result is accumulated in-place in c\_frag. **Args:** * ​a\_frag ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix A fragments as a `LayoutTensor`. * ​b\_frag ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix B fragments as a `LayoutTensor`. * ​c\_frag ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix C fragments as a `LayoutTensor` for both input and output.
--- ## TiledTensorCore
`struct TiledTensorCore[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool = False]` TiledTensorCore provides a wrapper around TensorCore to support multiple MMAs along the K dimension. Enables larger K dimension operations by decomposing them into multiple smaller MMA operations. Currently used only on AMD GPUs, to enable 16x16x32 operations using two 16x16x16 MMAs. ## Parameters * ​out\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for output/accumulation operations. * ​in\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for input matrix elements. * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape parameters for individual MMA operations \[M, N, K]. * ​group\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of MMA operations along the K dimension. * ​transpose\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose the B matrix. Defaults to False. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `mma_op` `comptime mma_op = TensorCore[out_type, in_type, shape, transpose_b]()` The underlying TensorCore instance for MMA operations. 
## Methods ### `mma` `static mma[swap_a_b: Bool = False](a_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Perform multiple matrix multiply-accumulate operations along the K dimension. Executes group\_size MMA operations, processing slices of the K dimension and accumulating results in c\_reg\_tile. **Parameters:** * ​swap\_a\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to swap a and b operands. Defaults to False. **Args:** * ​a\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix a fragments \[num\_m\_mmas, group\_size \* a\_frag\_size]. * ​b\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix b fragments \[num\_n\_mmas, group\_size \* b\_frag\_size]. * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Accumulation matrix c fragments, modified in-place.
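The K-decomposition that `TiledTensorCore.mma` performs can be modeled with plain Python matrices: `group_size` MMAs over K-slices accumulate to the same result as one full-K multiply. This is an illustrative model of the accumulation scheme, not of the register-fragment layouts:

```python
def matmul_acc(a, b, c):
    # one "MMA": c += a @ b for list-of-lists matrices
    m, k, n = len(a), len(b), len(b[0])
    for i in range(m):
        for j in range(n):
            c[i][j] += sum(a[i][p] * b[p][j] for p in range(k))
    return c

M, N, K, group_size = 4, 4, 8, 2   # e.g. two K=4 MMAs emulate one K=8 op
a = [[i + j for j in range(K)] for i in range(M)]
b = [[i * j + 1 for j in range(N)] for i in range(K)]

full = matmul_acc(a, b, [[0] * N for _ in range(M)])

tiled = [[0] * N for _ in range(M)]
kk = K // group_size
for g in range(group_size):  # slice K and accumulate in-place
    a_slice = [row[g * kk:(g + 1) * kk] for row in a]
    b_slice = b[g * kk:(g + 1) * kk]
    matmul_acc(a_slice, b_slice, tiled)

print(tiled == full)  # True
```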
--- ## get_fragment_size
`get_fragment_size[mma_shape: IndexList[3]]() -> IndexList[3]` Calculates the fragment size per thread for a given MMA shape. For tensor core operations, each thread in a warp handles a portion of the computation. This function determines how many elements each thread needs to process for the A, B, and C/D matrices based on the MMA shape. **Parameters:** * ​mma\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): An `IndexList[3]` containing the MMA dimensions \[M, N, K]. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): An `IndexList[3]` containing the fragment sizes per thread for matrices A, B, and C/D respectively, calculated as: `[M*K/WARP_SIZE, N*K/WARP_SIZE, M*N/WARP_SIZE]`.
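Using the documented formula with a warp size of 32 (an assumption for this sketch; the actual `WARP_SIZE` is hardware-dependent, e.g. 64 on many AMD GPUs), the per-thread fragment sizes for the common NVIDIA `16x8x16` shape work out as follows:

```python
WARP_SIZE = 32  # assumed for illustration

def get_fragment_size(m: int, n: int, k: int) -> list:
    # [A, B, C/D] elements per thread: [M*K/W, N*K/W, M*N/W]
    return [m * k // WARP_SIZE, n * k // WARP_SIZE, m * n // WARP_SIZE]

print(get_fragment_size(16, 8, 16))  # [8, 4, 4]
```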
--- ## get_mma_shape
`get_mma_shape[input_type: DType, accum_type: DType, shape_id: Int = 0]() -> IndexList[3]` Returns the appropriate matrix multiply-accumulate (MMA) shape for tensor core operations. Selects the optimal MMA shape based on the GPU architecture, input data type, accumulation data type, and optional shape identifier. This function handles different configurations for both NVIDIA and AMD GPUs. **Parameters:** * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the input matrices (A and B). * ​accum\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type used for accumulation (C and D). * ​shape\_id ([`Int`](/mojo/stdlib/builtin/int/Int)): Optional identifier to select between multiple valid shapes (default: 0). **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): An `IndexList[3]` containing the MMA dimensions in the format `[M, N, K]`, where `MxN` is the output matrix size and `K` is the reduction dimension.
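The selection logic can be illustrated with a small lookup over the shapes documented for each vendor (a hedged sketch only: the real function also considers the accumulation type, `shape_id`, and the exact GPU architecture):

```python
# illustrative shape tables drawn from the documented supported shapes
NVIDIA_SHAPES = {"float32": (16, 8, 8), "half": (16, 8, 16), "float8": (16, 8, 32)}
AMD_SHAPES = {"float32": (16, 16, 4), "half": (16, 16, 16)}

def get_mma_shape(input_type: str, is_nvidia: bool):
    # pick the [M, N, K] MMA shape for the input dtype and vendor
    table = NVIDIA_SHAPES if is_nvidia else AMD_SHAPES
    return table[input_type]

print(get_mma_shape("half", True))      # (16, 8, 16)
print(get_mma_shape("float32", False))  # (16, 16, 4)
```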
--- ## tensor_core
Tensor Core Module for High-Performance Matrix Operations. This module provides abstractions for using GPU Tensor Cores to perform optimized matrix operations. It supports both NVIDIA and AMD GPU architectures with hardware-specific optimizations. ## Key Components: * `TensorCore`: Core struct that encapsulates tensor core operations with support for various data types and matrix shapes. It handles loading matrix fragments, performing matrix multiply-accumulate operations, and storing results. * Matrix Fragment Management: Functions for loading and storing matrix fragments to/from shared memory with hardware-specific optimizations. * Matrix Multiply-Accumulate (MMA): Optimized implementations of matrix multiplication operations using tensor cores. ## Supported Operations: * Matrix loading with various layouts and swizzling patterns * Matrix multiply-accumulate (D = A \* B + C) * Matrix storing with hardware-specific optimizations ## Supported Data Types: * NVIDIA: float32, bfloat16, float16, float8\_e4m3fn, float8\_e5m2 * AMD: float32, bfloat16, float16 ## Supported Matrix Shapes: * NVIDIA: 16x8x8, 16x8x4, 16x8x16, 8x8x4, 16x8x32 * AMD: 16x16x4, 16x16x16, 32x32x8 ## `comptime` values ### `shape_16x16x16` `comptime shape_16x16x16 = IndexList[3, DType.int64](16, 16, 16, Tuple[]())` AMDGPU tensor core shape 16x16x16. ### `shape_16x16x32` `comptime shape_16x16x32 = IndexList[3, DType.int64](16, 16, 32, Tuple[]())` AMDGPU tensor core shape 16x16x32. ### `shape_16x16x4` `comptime shape_16x16x4 = IndexList[3, DType.int64](16, 16, 4, Tuple[]())` AMDGPU tensor core shape 16x16x4. ### `shape_16x8x16` `comptime shape_16x8x16 = IndexList[3, DType.int64](16, 8, 16, Tuple[]())` Tensor core shape 16x8x16. ### `shape_16x8x32` `comptime shape_16x8x32 = IndexList[3, DType.int64](16, 8, 32, Tuple[]())` Tensor core shape 16x8x32. ### `shape_16x8x4` `comptime shape_16x8x4 = IndexList[3, DType.int64](16, 8, 4, Tuple[]())` Tensor core shape 16x8x4. 
### `shape_16x8x8` `comptime shape_16x8x8 = IndexList[3, DType.int64](16, 8, 8, Tuple[]())` Tensor core shape 16x8x8. ### `shape_32x32x16` `comptime shape_32x32x16 = IndexList[3, DType.int64](32, 32, 16, Tuple[]())` AMDGPU tensor core shape 32x32x16. ### `shape_32x32x8` `comptime shape_32x32x8 = IndexList[3, DType.int64](32, 32, 8, Tuple[]())` AMDGPU tensor core shape 32x32x8. ### `shape_8x8x4` `comptime shape_8x8x4 = IndexList[3, DType.int64](8, 8, 4, Tuple[]())` Tensor core shape 8x8x4. ### `shape_null` `comptime shape_null = IndexList[3, DType.int64](0, 0, 0, Tuple[]())` Null tensor core shape (0x0x0). ## Structs * [​`TensorCore`](./TensorCore): TensorCore provides an abstraction for GPU tensor core hardware to perform optimized matrix operations. * [​`TiledTensorCore`](./TiledTensorCore): TiledTensorCore provides a wrapper around TensorCore to support multiple MMAs along the K dimension. ## Functions * [​`get_fragment_size`](./get_fragment_size): Calculates the fragment size per thread for a given MMA shape. * [​`get_mma_shape`](./get_mma_shape): Returns the appropriate matrix multiply-accumulate (MMA) shape for tensor core operations. * [​`load_b_nt`](./load_b_nt): Loads the b operand tile for AMD tensor core MFMA from (N, K) storage. * [​`load_b_tr`](./load_b_tr): Loads the b operand tile for AMD tensor core MFMA instructions using transposed memory access. * [​`num_matrix_reg`](./num_matrix_reg): Calculates the number of matrix registers required per thread.
--- ## load_b_nt
`load_b_nt[mma_shape: IndexList[3], swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]()](tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> SIMD[dtype, 8]` Loads the b operand tile for AMD tensor core MFMA from (N, K) storage. This function supports double-rate MFMA shapes (32x32x16, 16x16x32) with bfloat16 input. Unlike load\_b\_tr which expects (K, N) storage, this function works with (N, K) storage which is common when transpose\_b=True and B is stored row-major. The input tile (shape = (mma\_shape\[1], mma\_shape\[2])) is split along the K dimension into two halves of shape (MMA\_N, MMA\_K//2). Each half is loaded using `_load_tr16_b64_warp`, which performs a transposed (column-major) load from shared memory. The hardware transpose effectively converts the (N, K) storage to (K, N) format needed by MMA. Example: For 16x16x32 MMA with B stored as (N, K) = (16, 32) in LDS: ```mojo # B tile in LDS: shape (16, 32) = (MMA_N, MMA_K) var b_tile = smem_b.tile[16, 32](n_idx, k_idx) var b_reg = load_b_nt[IndexList[3](16, 16, 32)](b_tile) # b_reg now contains 8 bf16 values ready for MFMA ``` **Parameters:** * ​mma\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The MMA instruction tile shape (only 32x32x16 or 16x16x32 supported). * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzle pattern for bank-conflict-free LDS access. **Args:** * ​tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A `LayoutTensor`, residing in shared memory, with shape (mma\_shape\[1], mma\_shape\[2]) and dtype `DType.bfloat16`. This is (N, K) storage order. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD\[tile.dtype, 8]: Concatenated transposed SIMD loads from both halves of the tile.
--- ## load_b_tr
`load_b_tr[mma_shape: IndexList[3], swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]()](tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> SIMD[dtype, 8]` Loads the b operand tile for AMD tensor core MFMA instructions using transposed memory access. This function supports double-rate MFMA shapes (32x32x16, 16x16x32) with bfloat16 input. The input tile (shape = (mma\_shape\[2], mma\_shape\[1])) is split along the K dimension into two halves of shape (MMA\_K//2, MMA\_N). Each half is loaded using `_load_tr16_b64_warp`, which performs a transposed (column-major) load from shared memory. The resulting two 4-element SIMD vectors are concatenated into a single `SIMD[tile.dtype, 8]` vector. **Parameters:** * ​mma\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The MMA instruction tile shape (only 32x32x16 or 16x16x32 supported). * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzle pattern for bank-conflict-free LDS access. **Args:** * ​tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A `LayoutTensor`, residing in shared memory, with shape (mma\_shape\[2], mma\_shape\[1]) and dtype `DType.bfloat16`. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD\[tile.dtype, 8]: Concatenated transposed SIMD loads from both halves of the tile.
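Mirroring the `load_b_nt` example, a hedged usage sketch for the (K, N) storage order this function expects; the `smem_b`, `k_idx`, and `n_idx` names are assumptions for illustration:

```mojo
# B tile in LDS: shape (32, 16) = (MMA_K, MMA_N), i.e. (K, N) storage
var b_tile = smem_b.tile[32, 16](k_idx, n_idx)
var b_reg = load_b_tr[IndexList[3](16, 16, 32)](b_tile)
# b_reg contains 8 bf16 values ready for MFMA
```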
--- ## num_matrix_reg
`num_matrix_reg[dim_1: Int, dim_2: Int]() -> Int` Calculates the number of matrix registers required per thread. Determines how many registers each thread in a warp needs to store a matrix of the given dimensions. This is calculated by dividing the total number of elements (dim\_1 \* dim\_2) by the warp size, as the matrix is distributed across all threads in the warp. **Parameters:** * ​dim\_1 ([`Int`](/mojo/stdlib/builtin/int/Int)): First dimension of the matrix. * ​dim\_2 ([`Int`](/mojo/stdlib/builtin/int/Int)): Second dimension of the matrix. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of matrix registers needed per thread.
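The arithmetic is simply total elements divided by warp size, so the same fragment shape needs different register counts on different vendors (warp sizes assumed: 32 on NVIDIA, 64 on AMD):

```mojo
# 16x8 accumulator fragment distributed across one warp:
#   (16 * 8) / 32 = 4 registers per thread on NVIDIA (warp size 32)
#   (16 * 8) / 64 = 2 registers per thread on AMD (wavefront size 64)
comptime c_regs_per_thread = num_matrix_reg[16, 8]()
```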
--- ## TensorCoreAsync
`struct TensorCoreAsync[c_type: DType, a_type: DType, b_type: DType, mma_shape: IndexList[3], /, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, transpose_b: Bool = False]` High-performance asynchronous tensor core operations for matrix multiplication. This struct provides methods for utilizing NVIDIA's Tensor Cores for asynchronous matrix multiplication operations, with support for various data types and swizzling configurations. ## Parameters * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the output matrix C. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the input matrix A. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the input matrix B. * ​mma\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Dimensions for the matrix multiply-accumulate (MMA) operation as \[M, N, K]. * ​a\_swizzle ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzling mode for matrix A (default: SWIZZLE\_NONE). * ​b\_swizzle ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzling mode for matrix B (default: SWIZZLE\_NONE). * ​transpose\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose matrix B (default: False). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` Initialize the `TensorCoreAsync` instance. Ensures that the provided MMA shape is supported. Note: Fails to compile if `mma_shape` is not supported. 
### `wgmma` `static wgmma[num_warp_groups: Int = 1, scale_c: Int = 1, scale_a: Int = 1, scale_b: Int = 1, num_k_iters: OptionalReg[Int] = None](a_smem_tile: LayoutTensor[a_type, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], wg_idx: Int = 0)` Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA). This method handles the case where both A and B matrices are in shared memory. **Parameters:** * ​num\_warp\_groups ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of warp groups to distribute work across (default: 1). * ​scale\_c ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix C. Valid values are 1 or 0 (default: 1). * ​scale\_a ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix A. Valid values are 1 or -1 (default: 1). * ​scale\_b ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix B. Valid values are 1 or -1 (default: 1). * ​num\_k\_iters ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Number of iterations for the K dimension. This is useful to save computation when we pad shared memory. (default: None which is just `a_smem_layout[1].size() // mma_shape[2]`). **Args:** * ​a\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix A in shared memory. * ​b\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix B in shared memory. 
* ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C in register memory. * ​wg\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Warp group index for multi-warp group scenarios (default: 0). `static wgmma(a_frag_tile: LayoutTensor[a_type, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[b_type, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[c_type, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Perform asynchronous matrix multiplication using warp group matrix multiply-accumulate (WGMMA). This overloaded method handles the case where matrix A is in register memory and matrix B is in shared memory. **Args:** * ​a\_frag\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix A in register memory. * ​b\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Matrix B in shared memory. * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C in register memory. ### `arrive` `static arrive()` Ensures memory consistency by creating a fence for WGMMA operations. This method should be called before committing a group to ensure all shared memory accesses are properly aligned and visible. ### `commit_group` `static commit_group()` Commits the current warp group for execution. This synchronizes the warp group and commits all pending WGMMA operations that have been previously issued. 
### `wait_group` `static wait_group[group: Int = 0]()` Waits for the completion of a specific warp group's operations. This method blocks until all WGMMA operations from the specified group are complete. **Parameters:** * ​group ([`Int`](/mojo/stdlib/builtin/int/Int)): The group ID to wait for (default: 0).
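Putting the static methods together, a typical issue sequence is fence, issue, commit, wait. A hedged sketch of one WGMMA batch (the tile variables and the 64x128x16 shape are assumptions for illustration):

```mojo
comptime MMA = TensorCoreAsync[
    DType.float32, DType.bfloat16, DType.bfloat16,
    Index(64, 128, 16), transpose_b=True,
]
MMA.arrive()                                      # fence shared-memory accesses
MMA.wgmma(a_smem_tile, b_smem_tile, c_reg_tile)   # issue async MMA
MMA.commit_group()                                # commit pending WGMMA ops
MMA.wait_group()                                  # block until group 0 completes
```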
--- ## tensor_core_async
Tensor Core Async Module This module provides high-performance abstractions for utilizing NVIDIA's Tensor Cores to perform asynchronous matrix multiplication operations. It implements optimized memory layouts and access patterns for efficient tensor core computations. Key components: * Layout creation functions for K-major and MN-major memory arrangements * Swizzling support for improved memory access patterns * WGMMA (Warp Group Matrix Multiply-Accumulate) descriptor generation * TensorCoreAsync struct with methods for asynchronous matrix multiplication The module supports various data types, matrix dimensions, and memory configurations, enabling efficient implementation of deep learning primitives and other tensor operations that can leverage hardware acceleration. Performance features: * Asynchronous execution model to overlap computation and memory access * Support for different swizzling modes to optimize memory bandwidth * Efficient register and shared memory utilization * Support for multi-warp group execution This implementation is specifically optimized for NVIDIA GPUs with Tensor Core support. ## `comptime` values ### `WGMMA_K_BYTES` `comptime WGMMA_K_BYTES = 32` Size of WGMMA K dimension in bytes. ## Structs * [​`TensorCoreAsync`](./TensorCoreAsync): High-performance asynchronous tensor core operations for matrix multiplication. ## Functions * [​`select_k_atom`](./select_k_atom): Creates a core matrix layout for tensor core operations. * [​`st_matrix_n_atom`](./st_matrix_n_atom): Creates a layout for N-major `st_matrix` atom in the context of WGMMA C matrix. * [​`st_matrix_n_layout`](./st_matrix_n_layout): Creates a layout for N-major `st_matrix` in the context of WGMMA C matrix. * [​`tile_layout_k_major`](./tile_layout_k_major): Creates a K-major layout for tensor core operations. * [​`tile_layout_mn_major`](./tile_layout_mn_major): Creates an MN-major layout for tensor core operations. 
* [​`tile_sf_layout_k_major`](./tile_sf_layout_k_major): Creates a K-major layout for tensor core scale factors. * [​`tile_to_descriptor`](./tile_to_descriptor): Transforms a layout into a WGMMA descriptor-compatible layout. * [​`warpgroup_fence`](./warpgroup_fence): Code motion fence to ensure the registers of the WGMMA instruction do not get touched by anything. * [​`wgmma_c_layout`](./wgmma_c_layout): Generates three layouts for mapping WGMMA C matrix coordinates. * [​`wgmma_c_thread_layout`](./wgmma_c_thread_layout): Returns the thread layout component for WGMMA C matrix. * [​`wgmma_output_layout`](./wgmma_output_layout): Returns the output layout component for WGMMA C matrix.
--- ## select_k_atom
`select_k_atom[dtype: DType, swizzle_mode: TensorMapSwizzle]() -> Layout` Creates a core matrix layout for tensor core operations. Constructs the fundamental atomic layout for tensor core operations based on the specified data type and swizzle mode. This layout represents the minimal dense matrix structure that can be efficiently processed by tensor cores. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element data type of the tensor. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern swizzling mode. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A core matrix layout optimized for tensor core operations.
--- ## st_matrix_n_atom
`st_matrix_n_atom[num_stmatrix: Int]() -> Layout` Creates a layout for N-major `st_matrix` atom in the context of WGMMA C matrix. The domain of this layout is the warp group local thread index. Thus, the layout takes \[0, 128) as input and returns an offset for a logical array with an element size of 128-bit. **Parameters:** * ​num\_stmatrix ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of N-dimension tiles in the C matrix. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout that maps warp group local thread index to an offset for a logical array with an element size of 128-bit.
--- ## st_matrix_n_layout
`st_matrix_n_layout[c_type: DType, WG_BN: Int, num_m_mmas: Int, num_consumer: Int]() -> Layout` Creates a layout for N-major `st_matrix` in the context of WGMMA C matrix. The layout modes are: the warp group local thread index, the N-dimension tiling size `WG_BN // 16`, the number of MMA tiles `num_m_mmas` in the M-dimension, and the number of consumers `num_consumer`. The output is an offset for a logical array with the element type `c_type`. **Parameters:** * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the C matrix. * ​WG\_BN ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the K dimension in the C matrix in shared memory. * ​num\_m\_mmas ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of MMA tiles in the M dimension. * ​num\_consumer ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of consumers. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout that maps warp group local thread index to an offset for a logical array with the element type `c_type`.
--- ## tile_layout_k_major
`tile_layout_k_major[dtype: DType, BM: Int, BK: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]() -> Layout` Creates a K-major layout for tensor core operations. Constructs a layout optimized for K-major access patterns in tensor core operations, with optional swizzling for improved memory access patterns. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element data type of the tensor. * ​BM ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the M dimension in the tile. * ​BK ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the K dimension in the tile. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern swizzling mode (default: SWIZZLE\_NONE). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A K-major layout configured for the specified dimensions and swizzle mode.
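As a sketch, a shared-memory operand tile for WGMMA might combine this layout with 128-byte swizzling (the 64x64 bf16 tile size here is an assumption for illustration; a 64-element bf16 row is exactly 128 bytes, matching the swizzle width):

```mojo
# K-major shared-memory layout for a 64x64 bf16 WGMMA operand tile,
# swizzled for bank-conflict-free access.
comptime smem_layout = tile_layout_k_major[
    DType.bfloat16, 64, 64, TensorMapSwizzle.SWIZZLE_128B
]()
```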
--- ## tile_layout_mn_major
`tile_layout_mn_major[dtype: DType, mn_dim: Int, k_dim: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]() -> Layout` Creates an MN-major layout for tensor core operations. Constructs a unit layout optimized for MN-major access patterns in shared memory, with optional swizzling for improved memory access patterns. Note: This returns the "unit" layout; the actual shared memory layout can be a multiple of this unit. Currently only supports SWIZZLE\_NONE and SWIZZLE\_128B modes. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element data type of the tensor. * ​mn\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the MN dimension. * ​k\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the K dimension. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern swizzling mode (default: SWIZZLE\_NONE). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - An MN-major layout configured for the specified dimensions and swizzle mode.
--- ## tile_sf_layout_k_major
`tile_sf_layout_k_major[BM: Int, BK: Int, SF_SCALE_SIZE: Int]() -> Layout` Creates a K-major layout for tensor core scale factors. Constructs a layout for K-major access patterns for scale factors. **Parameters:** * ​BM ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the M dimension in the tile. * ​BK ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the K dimension in the tile. * ​SF\_SCALE\_SIZE ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements in a scale factor vector. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A K-major layout configured for the specified dimensions and scale factor size.
--- ## tile_to_descriptor
`tile_to_descriptor[dtype: DType, layout: Layout, is_k_major: Bool = True]() -> Layout` Transforms a layout into a WGMMA descriptor-compatible layout. Converts a standard layout into a form that can be used with WGMMA descriptors, handling both K-major and MN-major layouts differently. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element data type of the tensor. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Input layout to transform. * ​is\_k\_major ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the layout is K-major (True) or MN-major (False). **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A transformed layout compatible with WGMMA descriptors.
--- ## warpgroup_fence
`warpgroup_fence[accum_type: DType, accum_layout: Layout, //](accum: LayoutTensor[accum_type, accum_layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Code motion fence to ensure the registers of the WGMMA instruction do not get touched by anything. This has no impact on kernel correctness. It serves purely as an NVVM code motion barrier, preventing other operations from modifying the WGMMA instruction's registers during execution of the WGMMA instruction batch. **Parameters:** * ​accum\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element data type of the tensor. * ​accum\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Register layout of the accumulator. **Args:** * ​accum ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A LayoutTensor with the accum\_type and accum\_layout.
--- ## wgmma_c_layout
`wgmma_c_layout[mma_m: Int, mma_n: Int, C: Layout]() -> List[Layout]` Generates three layouts for mapping WGMMA C matrix coordinates. This function creates three layout mappings that are essential for working with WGMMA (Warp Group Matrix Multiply-Accumulate) operations: 1. A projection layout that maps linearized indices to row coordinates (i) 2. A projection layout that maps linearized indices to column coordinates (j) 3. A composite layout that maps thread and vector coordinates to linearized indices across multiple MMA tiles These layouts are particularly useful for operations like attention masking and matrix multiplication epilogues, where register values need to be mapped to the coordinate system of the C matrix. Note: This function enforces constraints on the WGMMA dimensions and ensures the C matrix dimensions are compatible with the WGMMA instruction size. **Parameters:** * ​mma\_m ([`Int`](/mojo/stdlib/builtin/int/Int)): The M dimension (rows) of a single WGMMA instruction, must be 64. * ​mma\_n ([`Int`](/mojo/stdlib/builtin/int/Int)): The N dimension (columns) of a single WGMMA instruction, must be multiple of 8. * ​C ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the C matrix within a thread block. **Returns:** [`List`](/mojo/stdlib/collections/list/List): `List[Layout]` - A list containing three layouts: 1. proj\_i: Maps linearized indices to row coordinates 2. proj\_j: Maps linearized indices to column coordinates 3. TV\_tile\_to\_idx: Maps thread/vector/tile coordinates to linearized indices
--- ## wgmma_c_thread_layout
`wgmma_c_thread_layout[C: Layout]() -> Layout` Returns the thread layout component for WGMMA C matrix. Generates the first mode of the WGMMA C layout, which maps thread coordinates to linearized indices in the output matrix. **Parameters:** * ​C ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the C matrix. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout mapping thread coordinates to linearized indices.
--- ## wgmma_output_layout
`wgmma_output_layout[mma_n: Int, C: Layout]() -> Layout` Returns the output layout component for WGMMA C matrix. Generates the second mode of the WGMMA C layout, which maps output vector coordinates to linearized indices in the output matrix. **Parameters:** * ​mma\_n ([`Int`](/mojo/stdlib/builtin/int/Int)): The N dimension of the WGMMA instruction. * ​C ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the C matrix. **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout): `Layout` - A layout mapping output vector coordinates to linearized indices.
--- ## PipelineState
`@register_passable(trivial)` `struct PipelineState[num_stages: Int]` Manages state for a multi-stage pipeline with circular buffer semantics. PipelineState provides a mechanism for tracking the current stage in a multi-stage pipeline, particularly useful for double or triple buffering in GPU tensor operations. It maintains an index that cycles through the available stages, a phase bit that toggles when the index wraps around, and a monotonically increasing count. This struct is commonly used with TMA operations to coordinate the use of multiple buffers in a pipeline fashion, allowing for overlapping computation and data transfer. ## Parameters * ​num\_stages ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of stages in the pipeline (e.g., 2 for double buffering, 3 for triple buffering). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Initialize a PipelineState with default values. Creates a new PipelineState with index 0, phase 0, and count 0. `__init__(index: Int, phase: Int, count: Int) -> Self` Initialize a PipelineState with specific values. Creates a new PipelineState with the specified index, phase, and count. **Args:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial stage index. * ​phase ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial phase value (0 or 1). * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial count value. 
### `index` `index(self) -> UInt32` Get the current stage index. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The current index value, which ranges from 0 to num\_stages-1. ### `phase` `phase(self) -> UInt32` Get the current phase bit. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The current phase value (0 or 1), which toggles when the index wraps around. ### `step` `step(mut self)` Advance the pipeline state to the next stage. Increments the index and count. When the index reaches num\_stages, it wraps around to 0 and toggles the phase bit. This function is used to move to the next buffer in a multi-buffer pipeline, implementing circular buffer semantics. ### `next` `next(mut self) -> Self` Advance the pipeline state to the next stage and return the new state. This function is used to move to the next buffer in a multi-buffer pipeline, implementing circular buffer semantics. **Returns:** `Self`: The new pipeline state after advancing to the next stage. ### `__enter__` `__enter__(var self) -> Self` Enter the context manager. **Returns:** `Self`: The pipeline state instance for use in a `with` statement.
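A hedged sketch of the double-buffering pattern these methods support (the barrier-wait body and `num_tiles` are assumptions for illustration):

```mojo
var state = PipelineState[2]()      # double buffering: stages 0 and 1
for _ in range(num_tiles):
    var stage = state.index()       # which buffer to consume: 0 or 1
    var phase = state.phase()       # flips each time the index wraps
    # ... wait on the stage's barrier against `phase`, then consume ...
    state.step()                    # 0 -> 1 -> 0 (phase toggles) -> 1 -> ...
```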
--- ## RaggedTensorMap
`struct RaggedTensorMap[descriptor_rank: Int, //, dtype: DType, descriptor_shape: IndexList[descriptor_rank], remaining_global_dim_rank: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE]` Creates a TMA descriptor that can handle stores with varying lengths. This struct is mainly used for MHA, where sequence lengths may vary between samples. This struct only supports one dimension being ragged. The contiguous dimension (where stride is 1) cannot be ragged. ## Parameters * ​descriptor\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the descriptor shape (inferred). * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the tensor. * ​descriptor\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the shared memory descriptor. * ​remaining\_global\_dim\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the remaining global tensor dimensions. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. Defaults to SWIZZLE\_NONE. ## Fields * ​descriptor (`TMADescriptor`): The TMA descriptor that will be used to store the ragged tensor. * ​max\_length (`Int`): The maximum length present in the sequences of the ragged tensor. * ​global\_shape (`IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]`): The shape of the global tensor. * ​global\_stride (`IndexList[RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode].global_rank]`): The stride of the global tensor. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = RaggedTensorMap[dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]` The TensorMapDescriptorArray type. ### `global_rank` `comptime global_rank = (remaining_global_dim_rank + 3)` The rank of the global tensor. ### `ragged_descriptor_shape` `comptime ragged_descriptor_shape = RaggedTensorMap._descriptor_shape[descriptor_rank, dtype, descriptor_shape, remaining_global_dim_rank, swizzle_mode]()` The shape of the descriptor that will tile and load from shared -> global memory. ## Methods ### `__init__` `__init__(out self, ctx: DeviceContext, global_ptr: LegacyUnsafePointer[Scalar[dtype]], max_length: Int, ragged_stride: Int, batch_size: Int, global_last_dim: Int, remaining_global_dims: IndexList[remaining_global_dim_rank], remaining_global_stride: IndexList[remaining_global_dim_rank])` Initializes a TensorMapDescriptorArray with descriptors for all power-of-2 lengths. This constructor creates a complete set of TMA descriptors, one for each power of 2 from 1 up to max\_descriptor\_length. Each descriptor is configured to handle a different first dimension size (1, 2, 4, 8, ..., max\_descriptor\_length) while maintaining the same remaining tile shape specified by desc\_remaining\_tile\_shape. 
**Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The device context used to create the TMA descriptors. * ​global\_ptr (`LegacyUnsafePointer`): The source tensor in global memory that will be accessed using the descriptors. * ​max\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum length present in the sequences of the ragged tensor. * ​ragged\_stride ([`Int`](/mojo/stdlib/builtin/int/Int)): The stride of the ragged dimension in the global tensor. * ​batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The total number of sequences in the ragged tensor. * ​global\_last\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The last dimension of the global tensor. * ​remaining\_global\_dims ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The dimensions of the remaining global tensor. * ​remaining\_global\_stride ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The stride of the remaining global tensor. **Raises:** If the operation fails. ### `get_type_name` `static get_type_name() -> String` Returns a string representation of the TensorMapDescriptorArray type. **Returns:** `String`: A string containing the type name with all template parameters. ### `get_device_type_name` `static get_device_type_name() -> String` Returns the device type name for this descriptor array. **Returns:** `String`: A string containing the type name with all template parameters. ### `store_ragged_tile` `store_ragged_tile[rank: Int, //, using_max_descriptor_size: Bool = False](self, coordinates: IndexList[rank], preceding_cumulative_length: Int, store_length: Int, mut tile_iterator: LayoutTensorIter[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])` Stores a ragged tile from shared memory to global memory. 
**Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the coordinates. * ​using\_max\_descriptor\_size ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, optimizes the store around the max descriptor size. **Args:** * ​coordinates ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The starting coordinates of all dimensions except the ragged dimension. * ​preceding\_cumulative\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The cumulative length of the preceding sequences. * ​store\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The length of the current sequence to be stored. * ​tile\_iterator ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): The iterator over the tile in shared memory. ### `prefetch_descriptor` `prefetch_descriptor(self)` Prefetches the TMA descriptor into cache.
--- ## SharedMemBarrier
`@register_passable(trivial)` `struct SharedMemBarrier` A hardware-accelerated synchronization primitive for GPU shared memory operations. This struct provides a barrier mechanism optimized for coordinating thread execution and memory transfers in GPU kernels, particularly for Tensor Memory Accelerator (TMA) operations. It enables efficient synchronization between threads and memory operations by leveraging hardware-specific barrier instructions. Key features: * Thread synchronization across thread blocks * Memory transfer completion tracking * Hardware-accelerated barrier operations * Support for phased synchronization This barrier is particularly useful for ensuring that shared memory operations complete before dependent computations begin, which is critical for maintaining data consistency in high-performance GPU kernels. ## Fields * ​mbar (`Int64`): Shared memory location used for the barrier state. This field stores an 8-byte aligned shared memory location that maintains the state of the barrier. The memory must be in shared address space to be accessible by all threads in a block. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `init` `init(ref [3] self, num_threads: Int32 = 1)` Initialize the barrier state with the expected number of threads. Sets up the barrier to expect arrivals from the specified number of threads before it can be satisfied. This is essential for coordinating thread synchronization in GPU kernels. 
**Args:** * ​num\_threads ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Number of threads that must arrive at the barrier before it is satisfied. Defaults to 1. ### `expect_bytes` `expect_bytes(ref [3] self, bytes: Int32)` Configure the barrier to expect a specific number of bytes to be transferred. Used with TMA operations to indicate the expected size of data transfer. The barrier will be satisfied when the specified number of bytes has been transferred, enabling efficient coordination of memory operations. **Args:** * ​bytes ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Number of bytes expected to be transferred. ### `expect_bytes_relaxed` `expect_bytes_relaxed(ref [3] self, bytes: Int32) -> UInt64` Configure the barrier to expect a specific number of bytes to be transferred. Used with TMA operations to indicate the expected size of data transfer. The barrier will be satisfied when the specified number of bytes has been transferred, enabling efficient coordination of memory operations. **Args:** * ​bytes ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Number of bytes expected to be transferred. **Returns:** `UInt64`: The state. ### `arrive_and_expect_bytes` `arrive_and_expect_bytes(ref [3] self, bytes: Int32, cta_id: UInt32, pred: UInt32)` Configure the barrier to expect a specific number of bytes to be transferred at a remote CTA. Used with TMA operations to indicate the expected size of data transfer. The barrier will be satisfied when the specified number of bytes has been transferred at the specified CTA in the cluster. **Args:** * ​bytes ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Number of bytes expected to be transferred. * ​cta\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The CTA ID in a cluster to configure an arrival. * ​pred ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Predication on the arrival configuration instruction. Uses UInt32 to match `selp.u32` in PTX.
### `wait` `wait[ticks: Optional[UInt32] = None](ref [3] self, phase: UInt32 = 0)` Wait until the barrier is satisfied. Blocks the calling thread until the barrier is satisfied, either by the expected number of threads arriving or the expected data transfer completing. This method implements an efficient spin-wait mechanism optimized for GPU execution. Note: Minimizes thread divergence during synchronization by using hardware-accelerated barrier instructions. **Parameters:** * ​ticks ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The number of nanoseconds to wait before timing out. Defaults to None. **Args:** * ​phase ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The phase value to check against. Defaults to 0. ### `wait_acquire` `wait_acquire[scope: Scope](ref [3] self, phase: UInt32 = 0)` Acquire and wait until the barrier is satisfied. Blocks the calling thread until the barrier is satisfied, either by the expected number of threads arriving or the expected data transfer completing. This method implements an efficient spin-wait mechanism optimized for GPU execution. Note: Minimizes thread divergence during synchronization by using hardware-accelerated barrier instructions. **Parameters:** * ​scope ([`Scope`](/mojo/stdlib/gpu/intrinsics/Scope)): The scope of the barrier. **Args:** * ​phase ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The phase value to check against. Defaults to 0. ### `wait_relaxed` `wait_relaxed[scope: Scope](ref [3] self, phase: UInt32 = 0)` Wait until the barrier is satisfied with relaxed ordering. Blocks the calling thread until the barrier is satisfied, either by the expected number of threads arriving or the expected data transfer completing. This method implements an efficient spin-wait mechanism optimized for GPU execution. Note: Minimizes thread divergence during synchronization by using hardware-accelerated barrier instructions.
**Parameters:** * ​scope ([`Scope`](/mojo/stdlib/gpu/intrinsics/Scope)): The scope of the barrier. **Args:** * ​phase ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The phase value to check against. Defaults to 0. ### `unsafe_ptr` `unsafe_ptr[mut: Bool, //, origin: Origin[mut]](ref [origin, 3] self) -> LegacyUnsafePointer[Int64, address_space=AddressSpace.SHARED, mut=mut, origin=origin]` Get an unsafe pointer to the barrier's memory location. Provides low-level access to the shared memory location storing the barrier state. This method is primarily used internally by other barrier operations that need direct access to the underlying memory. **Parameters:** * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Mutability of self. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): Origin of self. **Returns:** `LegacyUnsafePointer`: An unsafe pointer to the barrier's memory location in shared memory, properly typed and aligned for barrier operations. ### `arrive_cluster` `arrive_cluster(ref [3] self, cta_id: UInt32, count: UInt32 = 1)` Signal arrival at the barrier from a specific CTA (Cooperative Thread Array) in a cluster. This method is used in multi-CTA scenarios to coordinate barrier arrivals across different CTAs within a cluster. It enables efficient synchronization across thread blocks in clustered execution models. **Args:** * ​cta\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The ID of the CTA (Cooperative Thread Array) that is arriving. * ​count ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The number of arrivals to signal. Defaults to 1. ### `arrive` `arrive(ref [3] self) -> Int` Signal arrival at the barrier and return the arrival count. This method increments the arrival count at the barrier and returns the updated count. It's used to track how many threads have reached the synchronization point. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The updated arrival count after this thread's arrival.
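The arrival and wait methods above combine into the standard phased-barrier pattern: each completion of the barrier flips its phase, and waiters pass the phase they expect. A minimal device-side sketch follows; the kernel scaffolding, import paths, and the `stack_allocation` idiom for placing the barrier in shared memory are illustrative assumptions, not part of this API:

```mojo
from gpu import barrier, block_dim, thread_idx
from gpu.memory import AddressSpace
from layout.tma_async import SharedMemBarrier
from memory import stack_allocation

fn phased_sync_sketch():
    # The barrier state must live in shared memory, 8-byte aligned.
    var mbar = stack_allocation[
        1, SharedMemBarrier, address_space = AddressSpace.SHARED, alignment=8
    ]()
    if thread_idx.x == 0:
        mbar[0].init(Int32(block_dim.x))  # expect one arrival per thread
    barrier()

    # Phase 0: every thread arrives, then blocks until all have arrived.
    _ = mbar[0].arrive()
    mbar[0].wait(0)

    # The barrier has flipped to phase 1; the next round waits on that phase.
    _ = mbar[0].arrive()
    mbar[0].wait(1)
```

Tracking the phase explicitly (and toggling it after each `wait`) is what lets the same barrier be reused across loop iterations without reinitialization.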
--- ## TMATensorTile
`struct TMATensorTile[dtype: DType, layout: Layout, desc_layout: Layout = layout, is_k_major: Bool = True]` A hardware-accelerated tensor memory access (TMA) tile for efficient asynchronous data movement. The TMATensorTile struct provides a high-performance interface for asynchronous data transfers between global memory and shared memory in GPU tensor operations. It encapsulates a TMA descriptor that defines the memory access pattern and provides methods for various asynchronous operations. Performance: * Hardware-accelerated memory transfers using TMA instructions * Supports prefetching of descriptors for latency hiding * Enforces 128-byte alignment requirements for optimal memory access ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the tensor elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the tile in shared memory, typically specified as row\_major. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the descriptor, which can be different from the shared memory layout to accommodate hardware requirements like WGMMA. Defaults to `layout`. * ​is\_k\_major ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the shared memory is k-major. Defaults to True. ## Fields * ​descriptor (`TMADescriptor`): The TMA descriptor that defines the memory access pattern. This field stores the hardware descriptor that encodes information about: * The source tensor's memory layout and dimensions * The tile shape and access pattern * Swizzling configuration for optimal memory access The descriptor is used by the GPU's Tensor Memory Accelerator hardware to efficiently transfer data between global and shared memory.
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = TMATensorTile[dtype, layout, desc_layout, is_k_major]` The device-side type representation. ## Methods ### `__init__` `@implicit` `__init__(out self, descriptor: TMADescriptor)` Initializes a new TMATensorTile with the provided TMA descriptor. **Args:** * ​descriptor ([`TMADescriptor`](/mojo/stdlib/gpu/host/nvidia/tma/TMADescriptor)): The TMA descriptor that defines the memory access pattern. ### `__copyinit__` `__copyinit__(out self, other: Self)` Copy initializes this `TMATensorTile` from another instance. **Args:** * ​other (`Self`): The other `TMATensorTile` instance to copy from. ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. **Returns:** `String`: This type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. **Returns:** `String`: This type's name. ### `prefetch_descriptor` `prefetch_descriptor(self)` Prefetches the TMA descriptor into cache to reduce latency. This method helps hide memory access latency by prefetching the descriptor before it's needed for actual data transfers. 
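On the host, a `TMATensorTile` is not usually constructed from a raw descriptor by hand; the layout package's matmul examples build one through a factory and pass it to a kernel as an ordinary argument. A sketch along those lines; the `create_tma_tile` name, its signature, and the `my_load_kernel` function are assumptions based on those examples, not guaranteed API:

```mojo
from gpu.host import DeviceContext
from layout.tma_async import create_tma_tile
from utils.index import Index

comptime BM = 64
comptime BK = 64

# `ctx` is a DeviceContext and `a` a LayoutTensor over device memory
# (both assumed to exist in the surrounding host code).
a_tma = create_tma_tile[DType.bfloat16, 2, Index(BM, BK)](ctx, a)

# The tile is then handed to the kernel like any other argument;
# `my_load_kernel` is a hypothetical kernel that consumes it.
ctx.enqueue_function[my_load_kernel](a_tma, grid_dim=1, block_dim=128)
```

Building the descriptor once on the host and reusing it across launches is what makes the per-transfer cost on the device so low.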
### `async_copy` `async_copy[cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref [3] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt])` Schedules an asynchronous copy from global memory to shared memory at specified coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Parameters:** * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 2D coordinates in the source tensor from which to copy data. `async_copy[rank: Int, //, cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref [3] mem_barrier: SharedMemBarrier, coords: StaticTuple[UInt32, rank])` Schedules an asynchronous copy from global memory to shared memory for N-dimensional tensors.
This is a generic dispatcher that selects the appropriate rank-specific async copy method based on the tensor rank. It provides a unified interface for initiating TMA transfers across 2D, 3D, 4D, and 5D tensors using `StaticTuple` coordinates. **Constraints:** * The rank must be 2, 3, 4, or 5. * The destination tensor must be 128-byte aligned in shared memory. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensionality of the tensor (must be 2, 3, 4, or 5). * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): If set to 2, only the leader CTA needs to be notified upon completion. Defaults to 1. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)): The N-dimensional coordinates in the source tensor from which to copy data, provided as a `StaticTuple` of `UInt32` values. ### `async_copy_3d` `async_copy_3d(self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref [3] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt, UInt])` Schedules an asynchronous copy from global memory to shared memory at specified 3D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 3D tensors. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. 
**Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 3D coordinates in the source tensor from which to copy data. ### `async_copy_4d` `async_copy_4d[cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref [3] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt, UInt, UInt])` Schedules an asynchronous copy from global memory to shared memory at specified 4D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 4D tensors. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Parameters:** * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 4D coordinates in the source tensor from which to copy data.
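Inside a kernel, a single thread typically arms a `SharedMemBarrier` with `expect_bytes` and issues the copy, and every thread then waits on the barrier before touching the tile. A sketch under assumed tile dimensions; `tma_tile` is a `TMATensorTile` built on the host and passed in as a kernel argument, and `mbar` is a shared-memory barrier allocated as in the `SharedMemBarrier` section (both assumptions of this sketch):

```mojo
# Illustrative tile shape and element type.
comptime BM = 64
comptime BK = 64
comptime tile_bytes = BM * BK * size_of[DType.bfloat16]()

# Destination tile in shared memory; TMA requires 128-byte alignment.
smem_tile = LayoutTensor[
    DType.bfloat16,
    Layout.row_major(BM, BK),
    MutAnyOrigin,
    address_space = AddressSpace.SHARED,
    alignment=128,
].stack_allocation()

if thread_idx.x == 0:
    mbar[0].expect_bytes(tile_bytes)
    # Coordinates locate the tile's origin in the global tensor.
    tma_tile.async_copy(smem_tile, mbar[0], (UInt(0), UInt(0)))
barrier()
mbar[0].wait(0)  # returns once tile_bytes have landed in smem_tile
```

Because completion is signaled through the byte count rather than thread arrivals, only the issuing thread needs to interact with the transfer; the rest simply wait.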
### `async_copy_5d` `async_copy_5d(self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref [3] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt, UInt, UInt, UInt])` Schedules an asynchronous copy from global memory to shared memory at specified 5D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to the specified destination in shared memory for 5D tensors. The transfer is tracked by the provided memory barrier. **Constraints:** * The destination tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 5D coordinates in the source tensor from which to copy data. ### `async_store` `async_store[rank: Int, //, cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: StaticTuple[UInt32, rank])` Schedules an asynchronous store from shared memory to global memory for N-dimensional tensors. This is a generic dispatcher that selects the appropriate rank-specific async store method based on the tensor rank. It provides a unified interface for initiating TMA store operations across 2D, 3D, 4D, and 5D tensors using `StaticTuple` coordinates. 
**Constraints:** * The rank must be 2, 3, 4, or 5. * The source tensor must be 128-byte aligned in shared memory. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensionality of the tensor (must be 2, 3, 4, or 5). * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): CTA group configuration for the store operation. Defaults to 1. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied to global memory. Must be 128-byte aligned. * ​coords ([`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)): The N-dimensional coordinates in the destination global tensor where data will be stored, provided as a `StaticTuple` of `UInt32` values. `async_store(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Schedules an asynchronous store from shared memory to global memory. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to global memory at the specified coordinates. **Constraints:** The source tensor must be 128-byte aligned in shared memory. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 2D coordinates in the destination tensor where data will be stored.
### `async_multicast_load` `async_multicast_load[cta_group: Int = 1](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ref [3] mem_barrier: SharedMemBarrier, coords: Tuple[UInt, UInt], multicast_mask: UInt16)` Schedules an asynchronous multicast load from global memory to multiple shared memory locations. This method initiates a hardware-accelerated asynchronous transfer of data from global memory to multiple destination locations in shared memory across different CTAs (Cooperative Thread Arrays) as specified by the multicast mask. **Constraints:** The destination tensor must be 128-byte aligned in shared memory. **Parameters:** * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): If the TMA is issued with cta\_group == 2, only the leader CTA needs to be notified upon completion. Defaults to 1. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 2D coordinates in the source tensor from which to copy data. * ​multicast\_mask ([`UInt16`](/mojo/stdlib/builtin/simd/#uint16)): A bit mask specifying which CTAs should receive the data.
### `async_multicast_load_partitioned` `async_multicast_load_partitioned[tma_rows: Int, tma_load_size: Int](self, dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], ref [3] mem_barrier: SharedMemBarrier, rank: UInt, coords: Tuple[UInt, UInt], multicast_mask: UInt16)` Performs a partitioned multicast load where each rank loads a distinct slice of data. This method is designed for clustered execution where different ranks (CTAs) load different, contiguous slices of the source tensor. Each rank's slice is offset by `rank * tma_rows` in the second dimension and stored at offset `rank * tma_load_size` in shared memory. Note: This is typically used in matrix multiplication kernels where the input matrices are partitioned across multiple CTAs for parallel processing. **Parameters:** * ​tma\_rows ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of rows each rank is responsible for loading. * ​tma\_load\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size in elements of each rank's slice in shared memory. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The destination tensor in shared memory where data will be copied. Must be 128-byte aligned. * ​mem\_barrier ([`SharedMemBarrier`](/mojo/kernels/layout/tma_async/SharedMemBarrier)): The memory barrier used to track and synchronize the asynchronous transfer. * ​rank ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The rank ID (0-based) that determines which slice to load. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The base 2D coordinates in the source tensor from which to copy data. The second coordinate will be offset by `rank * tma_rows`. * ​multicast\_mask ([`UInt16`](/mojo/stdlib/builtin/simd/#uint16)): A bit mask specifying which CTAs should receive the data. 
### `async_store_3d` `async_store_3d(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt, UInt])` Schedules an asynchronous store from shared memory to global memory at specified 3D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 3D tensors. **Constraints:** * The source tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 3D coordinates in the destination tensor where data will be stored. ### `async_store_4d` `async_store_4d(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt, UInt, UInt])` Schedules an asynchronous store from shared memory to global memory at specified 4D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 4D tensors. **Constraints:** * The source tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. 
* ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 4D coordinates in the destination tensor where data will be stored. ### `async_store_5d` `async_store_5d(self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt, UInt, UInt, UInt])` Schedules an asynchronous store from shared memory to global memory at specified 5D coordinates. This method initiates a hardware-accelerated asynchronous transfer of data from shared memory to the specified destination in global memory for 5D tensors. **Constraints:** * The source tensor must be 128-byte aligned in shared memory. * The descriptor layout may be smaller than the shared memory tile shape to accommodate hardware requirements. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory from which data will be copied. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 5D coordinates in the destination tensor where data will be stored. ### `async_reduce` `async_reduce[reduction_kind: ReduceOp](self, src: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Schedules an asynchronous reduction operation from shared memory to global memory. This method initiates a hardware-accelerated asynchronous reduction operation that combines data from shared memory with data in global memory using the specified reduction operation. The reduction is performed element-wise at the specified coordinates in the global tensor. **Constraints:** The source tensor must be 128-byte aligned in shared memory. 
**Parameters:** * ​reduction\_kind ([`ReduceOp`](/mojo/stdlib/gpu/memory/memory/ReduceOp)): The type of reduction operation to perform (e.g., ADD, MIN, MAX). This determines how values are combined during the reduction. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor in shared memory containing the data to be reduced. Must be 128-byte aligned. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The 2D coordinates in the destination tensor where the reduction will be applied. ### `commit_group` `commit_group(self)` Commits all previously initiated but uncommitted TMA instructions into a group. This function behaves the same as `cp_async_bulk_commit_group`, which creates a synchronization point for bulk TMA transfers. ### `wait_group` `wait_group[n: Int = 0](self)` Waits for asynchronous copies to complete until at most a specified number of groups remain pending. This function behaves the same as `cp_async_bulk_wait_group`, which causes the executing thread to wait until at most `n` of the most recently committed TMA copy groups are still pending. **Parameters:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of groups allowed to remain pending. ### `smem_tensormap_init` `smem_tensormap_init(self, smem_tma_descriptor_ptr: LegacyUnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED])` Initializes a TMA descriptor in shared memory from this tensor tile's descriptor. This method copies the TMA descriptor from global memory to shared memory, allowing for faster access during kernel execution. The descriptor is copied in 16-byte chunks using asynchronous copy operations for efficiency. Note: * Only one thread should call this method to avoid race conditions * The descriptor is copied in 8 chunks of 16 bytes each (total 128 bytes) **Args:** * ​smem\_tma\_descriptor\_ptr (`LegacyUnsafePointer`): Pointer to the location in shared memory where the descriptor will be stored. Must be properly aligned.
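`commit_group` and `wait_group` bracket stores the way barriers bracket loads: after issuing `async_store`, the pending transfers are committed into a group and drained before the shared-memory buffer is reused. A sketch, carrying over the hypothetical `tma_tile` and `smem_tile` names used in earlier examples:

```mojo
if thread_idx.x == 0:
    # Write the computed tile back to global memory at element (0, 0).
    tma_tile.async_store(smem_tile, (UInt(0), UInt(0)))
    tma_tile.commit_group()   # close the group of pending TMA stores
    tma_tile.wait_group[0]()  # block until no groups remain pending
barrier()  # after this, every thread may safely overwrite smem_tile
```

Passing `n > 0` to `wait_group` instead allows up to `n` groups to stay in flight, which is how multi-buffered pipelines overlap stores with the next tile's compute.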
### `replace_tensormap_global_address_in_gmem` `replace_tensormap_global_address_in_gmem[_dtype: DType](self, src_ptr: LegacyUnsafePointer[Scalar[_dtype]])` Replaces the global memory address in the TMA descriptor stored in global memory. This method allows dynamically changing the source tensor for TMA operations without recreating the entire descriptor, which is useful for reusing descriptors with different data sources. The operation modifies the descriptor in global memory directly. Note: A memory fence may be required after this operation to ensure visibility of the changes to other threads. **Parameters:** * ​\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the new source tensor. **Args:** * ​src\_ptr (`LegacyUnsafePointer`): The new source tensor whose address will replace the current one in the descriptor. Must have compatible layout with the original tensor. ### `tensormap_fence_acquire` `tensormap_fence_acquire(self)` Establishes a memory fence for TMA operations with acquire semantics. This method ensures proper ordering of memory operations by creating a barrier that prevents subsequent TMA operations from executing before prior operations have completed. It is particularly important when reading from a descriptor that might have been modified by other threads or processes. The acquire semantics ensure that all memory operations after this fence will observe any modifications made to the descriptor before the fence. Notes: * The entire warp must call this function as the instruction is warp-aligned. * Typically used in pairs with `tensormap_fence_release` for proper synchronization. ### `tensormap_fence_release` `tensormap_fence_release(self)` Establishes a memory fence for TMA operations with release semantics. This method ensures proper ordering of memory operations by creating a barrier that ensures all prior memory operations are visible before subsequent operations can proceed. 
It is particularly important when modifying a TMA descriptor in global memory that might be read by other threads or processes. The release semantics ensure that all memory operations before this fence will be visible to any thread that observes operations after the fence. Notes: * Typically used after modifying a tensormap descriptor in global memory. * Often paired with `tensormap_fence_acquire` for proper synchronization. ### `replace_tensormap_global_address_in_shared_mem` `replace_tensormap_global_address_in_shared_mem[_dtype: DType](self, smem_tma_descriptor_ptr: LegacyUnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED, mut=mut, origin=origin], src_ptr: LegacyUnsafePointer[Scalar[_dtype]])` Replaces the global memory address in the TMA descriptor stored in shared memory. This method allows dynamically changing the source tensor for TMA operations without recreating the entire descriptor, which is useful for reusing descriptors with different data sources. The operation modifies a descriptor that has been previously copied to shared memory. Notes: * Only one thread should call this method to avoid race conditions. * A memory fence may be required after this operation to ensure visibility of the changes to other threads. * Typically used with descriptors previously initialized with `smem_tensormap_init`. **Parameters:** * ​\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the new source tensor. **Args:** * ​smem\_tma\_descriptor\_ptr (`LegacyUnsafePointer`): Pointer to the TMA descriptor in shared memory that will be modified. * ​src\_ptr (`LegacyUnsafePointer`): The new source tensor whose address will replace the current one in the descriptor. ### `tensormap_cp_fence_release` `tensormap_cp_fence_release(self, smem_tma_descriptor_ptr: LegacyUnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED])` Establishes a memory fence for TMA operations with release semantics for shared memory descriptors. 
This method ensures proper ordering of memory operations by creating a barrier that ensures all prior memory operations are visible before subsequent operations can proceed. It is specifically designed for synchronizing between global memory and shared memory TMA descriptors. The release semantics ensure that all memory operations before this fence will be visible to any thread that observes operations after the fence. Notes: * The entire warp must call this function as the instruction is warp-aligned * Typically used after modifying a tensormap descriptor in shared memory * More specialized than the general `tensormap_fence_release` for cross-memory space synchronization **Args:** * ​smem\_tma\_descriptor\_ptr (`LegacyUnsafePointer`): Pointer to the TMA descriptor in shared memory that is being synchronized with the global memory descriptor. ### `replace_tensormap_global_dim_strides_in_shared_mem` `replace_tensormap_global_dim_strides_in_shared_mem[_dtype: DType, only_update_dim_0: Bool, /, *, rank: Int](self, smem_tma_descriptor_ptr: LegacyUnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED, mut=mut, origin=origin], gmem_dims: IndexList[rank], gmem_strides: IndexList[rank])` Replaces dimensions and strides in a TMA descriptor stored in shared memory. Note: This function is only supported for CUDA versions >= 12.5. This function allows dynamically modifying the dimensions and strides of a TMA descriptor that has been previously initialized in shared memory. If only the first dimension (dim 0) is updated, then updating strides can be skipped. Notes: * Only one thread should call this method to avoid race conditions. * A memory fence may be required after this operation to ensure visibility of the changes to other threads. **Parameters:** * ​\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the new source tensor. 
* ​only\_update\_dim\_0 ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true, only the first dimension (dim 0) is updated, without updating strides. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the tensor. **Args:** * ​smem\_tma\_descriptor\_ptr (`LegacyUnsafePointer`): Pointer to the TMA descriptor in shared memory that will be modified. * ​gmem\_dims ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The global dimensions of the tensor to be updated. * ​gmem\_strides ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The global strides of the tensor to be updated. `replace_tensormap_global_dim_strides_in_shared_mem[_dtype: DType, tensor_rank: Int, dim_idx: Int](self, smem_tma_descriptor_ptr: LegacyUnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED, mut=mut, origin=origin], dim_value: UInt32, dim_stride: Optional[UInt64] = None)` Replaces dimensions and strides in a TMA descriptor stored in shared memory. Note: This function is only supported for CUDA versions >= 12.5. This function allows dynamically modifying the dimensions and strides of a TMA descriptor that has been previously initialized in shared memory. If only the first dimension is updated, then updating strides can be skipped. Notes: * Only one thread should call this method to avoid race conditions. * A memory fence may be required after this operation to ensure visibility of the changes to other threads. **Parameters:** * ​\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the source tensor in GMEM. * ​tensor\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the source tensor in GMEM. * ​dim\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the dimension to be updated in the TMA descriptor with the provided dimension and stride values at runtime. **Args:** * ​smem\_tma\_descriptor\_ptr (`LegacyUnsafePointer`): Pointer to the TMA descriptor in shared memory that will be modified. 
* ​dim\_value ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The new dimension value to be set. * ​dim\_stride ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The new stride value to be set.
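Taken together, the replace-and-fence methods above are typically used as a sequence inside a kernel. The fragment below is a non-authoritative sketch: `tma_tile`, `smem_desc`, and `new_src` are assumed variables with the types shown in the signatures above, and the `thread_idx` import path is an assumption.

```mojo
from gpu.id import thread_idx  # import path assumed

# Hypothetical kernel fragment: retarget a shared-memory TMA descriptor.
# Assumes:
#   tma_tile:  a TMATensorTile instance (layout parameters elided)
#   smem_desc: LegacyUnsafePointer[TMADescriptor, address_space=AddressSpace.SHARED]
#   new_src:   LegacyUnsafePointer[Scalar[dtype]] for the replacement tensor

# Only one thread rewrites the global address inside the descriptor.
if thread_idx.x == 0:
    tma_tile.replace_tensormap_global_address_in_shared_mem(smem_desc, new_src)

# The entire warp publishes the modified descriptor with release semantics...
tma_tile.tensormap_cp_fence_release(smem_desc)

# ...and pairs it with an acquire fence so all threads observe the new
# address before the next TMA operation reads the descriptor.
tma_tile.tensormap_fence_acquire()
```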
--- ## TMATensorTileArray
`@register_passable(trivial)` `struct TMATensorTileArray[num_of_tensormaps: Int, dtype: DType, cta_tile_layout: Layout, desc_layout: Layout]` An array of TMA descriptors. ## Parameters * ​num\_of\_tensormaps ([`Int`](/mojo/stdlib/builtin/int/Int)): Int The number of TMA descriptors (tensor maps). * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType The data type of the tensor elements. * ​cta\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout The layout of the tile in shared memory, typically specified as row\_major. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout The layout of the descriptor, which can be different from the shared memory layout to accommodate hardware requirements like WGMMA. ## Fields * ​tensormaps\_ptr (`LegacyUnsafePointer[UInt8]`): A static tuple of pointers to TMA descriptors. This field stores an array of pointers to `TMATensorTile` instances, where each pointer references a TMA descriptor in device memory. The array has a fixed size determined by the num\_of\_tensormaps parameter. The TMA descriptors are used by the GPU hardware to efficiently transfer data between global and shared memory with specific memory access patterns defined by the layouts. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `descriptor_bytes` `comptime descriptor_bytes = 128` Size of the TMA descriptor in bytes. 
This is a constant value that represents the size of the TMA descriptor in bytes. It is used to calculate the offset of the TMA descriptor in the device memory. ### `device_type` `comptime device_type = TMATensorTileArray[num_of_tensormaps, dtype, cta_tile_layout, desc_layout]` The device-side type representation. ## Methods ### `__init__` `__init__(tensormaps_device: DeviceBuffer[DType.uint8]) -> Self` Initializes a new TMATensorTileArray. **Args:** * ​tensormaps\_device ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Device buffer to store TMA descriptors. ### `__getitem__` `__getitem__(self, index: Int) -> LegacyUnsafePointer[TMATensorTile[dtype, cta_tile_layout, desc_layout]]` Retrieve a TMA descriptor. **Args:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): Index of the TMA descriptor. **Returns:** `LegacyUnsafePointer`: `UnsafePointer` to the `TMATensorTile` at the specified index. ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. **Returns:** `String`: This type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. **Returns:** `String`: This type's name.
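As a host-side illustration, the array can be backed by a device buffer sized at `descriptor_bytes` per entry. This is a minimal sketch; the buffer-allocation call, import paths, and layouts are assumptions, not part of the documented API above.

```mojo
from gpu.host import DeviceContext               # import path assumed
from layout import Layout                        # import path assumed
from layout.tma_async import TMATensorTileArray  # import path assumed

comptime N = 4  # number of descriptors; illustrative

fn make_descriptor_array(ctx: DeviceContext) raises:
    # descriptor_bytes == 128, so reserve 128 bytes of device memory per entry.
    var storage = ctx.enqueue_create_buffer[DType.uint8](N * 128)
    var tma_array = TMATensorTileArray[
        N, DType.float32, Layout.row_major(64, 64), Layout.row_major(64, 64)
    ](storage)
    # __getitem__ returns a pointer to the i-th TMATensorTile.
    var first_desc = tma_array[0]
```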
--- ## create_split_tma
`create_split_tma[rank: Int, dtype: DType, //, smem_shape: IndexList[rank], gmem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle](ctx: DeviceContext, ptr: LegacyUnsafePointer[Scalar[dtype]], runtime_dim0: Int, out res: TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)])` Creates a TMA tensor tile assuming that the first dimension in global memory has `UNKNOWN_VALUE`. This function creates a `TMATensorTile` that optionally splits the last dimension of the tensor into multiples of swizzle granularity. This functionality is currently disabled because it was not found to improve performance. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions of the tensor. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the tensor elements. * ​smem\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the tile in shared memory. * ​gmem\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the global memory tensor. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode for memory access optimization. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The CUDA device context used to create the TMA descriptor. * ​ptr (`LegacyUnsafePointer`): Pointer to the global memory tensor data. * ​runtime\_dim0 ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime size of the first dimension of the global tensor. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): The resulting TMA tensor tile with split layout. **Raises:** If TMA descriptor creation fails. 
`create_split_tma[rank: Int, dtype: DType, //, smem_shape: IndexList[rank], gmem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle](ctx: DeviceContext, ptr: LegacyUnsafePointer[Scalar[dtype]], runtime_dim0: Int, runtime_dim1: Int, out res: TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)])` Creates a TMA tensor tile assuming that the first two dimensions in global memory have `UNKNOWN_VALUE`. This function creates a `TMATensorTile` that optionally splits the last dimension of the tensor into multiples of swizzle granularity. This functionality is currently disabled because it was not found to improve performance. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of dimensions of the tensor. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the tensor elements. * ​smem\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the tile in shared memory. * ​gmem\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the global memory tensor. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode for memory access optimization. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The CUDA device context used to create the TMA descriptor. * ​ptr (`LegacyUnsafePointer`): Pointer to the global memory tensor data. * ​runtime\_dim0 ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime size of the first dimension of the global tensor. * ​runtime\_dim1 ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime size of the second dimension of the global tensor. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): The resulting TMA tensor tile with split layout. **Raises:** If TMA descriptor creation fails.
--- ## create_tma_tile
`create_tma_tile[*tile_sizes: Int, *, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](ctx: DeviceContext, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TMATensorTile[dtype, Layout.row_major(_to_int_tuple[tile_sizes]())]` Creates a `TMATensorTile` with specified tile dimensions and swizzle mode. This function creates a hardware-accelerated Tensor Memory Access (TMA) descriptor for efficient asynchronous data transfers between global memory and shared memory. It configures the tile dimensions and memory access patterns based on the provided parameters. **Constraints:** * The last dimension's size in bytes must not exceed the swizzle mode's byte limit (32B for SWIZZLE\_32B, 64B for SWIZZLE\_64B, 128B for SWIZZLE\_128B). * Only supports 2D tensors in this overload. **Parameters:** * ​\*tile\_sizes ([`Int`](/mojo/stdlib/builtin/int/Int)): The dimensions of the tile to be transferred. For 2D tensors, this should be \[height, width]. The dimensions determine the shape of data transferred in each TMA operation. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode to use for memory access optimization. Swizzling can improve memory access patterns for specific hardware configurations. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The CUDA device context used to create the TMA descriptor. * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The source tensor from which data will be transferred. This defines the global memory layout and data type. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): A `TMATensorTile` configured with the specified tile dimensions and swizzle mode, ready for use in asynchronous data transfer operations. 
`create_tma_tile[dtype: DType, rank: Int, //, tile_shape: IndexList[rank], /, k_major_tma: Bool = True, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, *, __tile_layout: Layout = Layout.row_major(tile_shape.__getitem__[rank, DType.int64, Int](0), tile_shape.__getitem__[rank, DType.int64, Int](1)), __desc_layout: Layout = _tma_desc_tile_layout[dtype, rank, tile_shape, swizzle_mode]()](ctx: DeviceContext, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> TMATensorTile[dtype, __tile_layout, __desc_layout, k_major_tma]` Creates a `TMATensorTile` with advanced configuration options for 2D, 3D, 4D, or 5D tensors. This overload provides more control over the TMA descriptor creation, allowing specification of data type, rank, and layout orientation. It supports 2D, 3D, 4D, and 5D tensors and provides fine-grained control over the memory access patterns. **Constraints:** * Only supports 2D, 3D, 4D, and 5D tensors (rank must be 2, 3, 4, or 5). * For non-SWIZZLE\_NONE modes, the K dimension size in bytes must be a multiple of the swizzle mode's byte size. * For MN-major layout, only SWIZZLE\_128B is supported. * For 3D, 4D, and 5D tensors, only K-major layout is supported. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType The data type of the tensor elements. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Int The dimensionality of the tensor (must be 2, 3, 4, or 5). * ​tile\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): IndexList\[rank] The shape of the tile to be transferred. * ​k\_major\_tma ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool = True Whether the TMA engine should copy the descriptor into shared memory following a column-major (if `True`) or row-major (if `False`) pattern. 
* ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): TensorMapSwizzle = TensorMapSwizzle.SWIZZLE\_NONE The swizzling mode to use for memory access optimization. * ​\_\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout = Layout.row\_major(tile\_shape\[0], tile\_shape\[1]) Internal parameter for the tile layout in shared memory. * ​\_\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout = \_tma\_desc\_tile\_layout\[...] Internal parameter for the descriptor layout, which may differ from the tile layout to accommodate hardware requirements. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): DeviceContext The CUDA device context used to create the TMA descriptor. * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): LayoutTensor\[dtype, \**, \*\**] The source tensor from which data will be transferred. This defines the global memory layout and must match the specified data type. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): A `TMATensorTile` configured with the specified parameters, ready for use in asynchronous data transfer operations.
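A minimal host-side usage sketch of the first overload, assuming a row-major bf16 tensor in global memory (import paths and shapes are illustrative assumptions, not documented API):

```mojo
from gpu.host import DeviceContext                # import path assumed
from gpu.host.nvidia.tma import TensorMapSwizzle  # import path assumed
from layout import Layout, LayoutTensor           # import path assumed
from layout.tma_async import create_tma_tile      # import path assumed

fn build_tile(
    ctx: DeviceContext,
    tensor: LayoutTensor[
        DType.bfloat16, Layout.row_major(1024, 1024), MutAnyOrigin
    ],
) raises:
    # A 64x64 bf16 tile: the last dimension is 64 * 2 = 128 bytes, which
    # satisfies the SWIZZLE_128B byte-limit constraint described above.
    var tile = create_tma_tile[
        64, 64, swizzle_mode = TensorMapSwizzle.SWIZZLE_128B
    ](ctx, tensor)
```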
--- ## create_tma_tile_template
`create_tma_tile_template[dtype: DType, rank: Int, tile_shape: IndexList[rank], /, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, *, __tile_layout: Layout = Layout.row_major(tile_shape.__getitem__[rank, DType.int64, Int](0), tile_shape.__getitem__[rank, DType.int64, Int](1)), __desc_layout: Layout = _tma_desc_tile_layout[dtype, rank, tile_shape, swizzle_mode]()]() -> TMATensorTile[dtype, __tile_layout, __desc_layout]` Same as create\_tma\_tile, except the descriptor is only a placeholder or template for later replacement. It allows specification of data type, rank, and layout orientation, supports both 2D and 3D tensors, and provides fine-grained control over the memory access patterns. **Constraints:** * Only supports 2D and 3D tensors (rank must be 2 or 3). * For non-SWIZZLE\_NONE modes, the K dimension size in bytes must be a multiple of the swizzle mode's byte size. * For MN-major layout, only SWIZZLE\_128B is supported. * For 3D tensors, only K-major layout is supported. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType The data type of the tensor elements. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Int The dimensionality of the tensor (must be 2 or 3). * ​tile\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): IndexList\[rank] The shape of the tile to be transferred. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): TensorMapSwizzle = TensorMapSwizzle.SWIZZLE\_NONE The swizzling mode to use for memory access optimization. * ​\_\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout = Layout.row\_major(tile\_shape\[0], tile\_shape\[1]) Internal parameter for the tile layout in shared memory. * ​\_\_desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout = \_tma\_desc\_tile\_layout\[...] Internal parameter for the descriptor layout, which may differ from the tile layout to accommodate hardware requirements. 
**Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile): A `TMATensorTile` configured with the specified parameters, ready for use in asynchronous data transfer operations.
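In a typical flow, a template descriptor is created once on the host and its global address is patched later on device (see `replace_tensormap_global_address_in_shared_mem`). A sketch, with import paths and shapes as assumptions:

```mojo
from gpu.host.nvidia.tma import TensorMapSwizzle       # import path assumed
from layout.tma_async import create_tma_tile_template  # import path assumed
from utils.index import IndexList                      # import path assumed

fn make_template() raises:
    # Placeholder descriptor: a 2D 64x64 bf16 tile with 128-byte swizzle.
    # The global address it points to is replaced at runtime.
    var template_tile = create_tma_tile_template[
        DType.bfloat16,
        2,
        IndexList[2](64, 64),
        swizzle_mode = TensorMapSwizzle.SWIZZLE_128B,
    ]()
```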
--- ## tma_async
Tensor Memory Accelerator (TMA) Asynchronous Operations Module. Provides high-performance abstractions for NVIDIA's Tensor Memory Accelerator (TMA), enabling efficient asynchronous data movement between global and shared memory in GPU kernels. It is designed for use with NVIDIA Hopper architecture and newer GPUs that support TMA instructions. ## Key Components: * `TMATensorTile`: Core struct that encapsulates a TMA descriptor for efficient data transfers between global and shared memory with various access patterns and optimizations. * `SharedMemBarrier`: Synchronization primitive for coordinating asynchronous TMA operations, ensuring data transfers complete before dependent operations begin. * `PipelineState`: Helper struct for managing multi-stage pipeline execution with circular buffer semantics, enabling efficient double or triple buffering techniques. * `create_tma_tile`: Factory functions for creating optimized `TMATensorTile` instances with various configurations for different tensor shapes and memory access patterns. ## `comptime` values ### `SplitLastDimTMATensorTile` `comptime SplitLastDimTMATensorTile[rank: Int, //, dtype: DType, smem_shape: IndexList[rank], swizzle_mode: TensorMapSwizzle] = TMATensorTile[dtype, _split_last_layout[dtype](smem_shape, swizzle_mode, True), _ragged_desc_layout[dtype](smem_shape, swizzle_mode)]` A specialized TMA tensor tile type alias that handles layouts where the last dimension is split based on swizzle granularity for optimal memory access patterns. The current behavior is to not actually split the last dimension. #### Parameters * ​rank ([`Int`](/stdlib/builtin/int/Int)): The number of dimensions of the tensor. * ​dtype ([`DType`](/stdlib/builtin/dtype/DType)): The data type of the tensor elements. * ​smem\_shape ([`IndexList`](/stdlib/utils/index_/IndexList)): The shape of the tile in shared memory. The last dimension will be padded if necessary to align with the swizzle granularity. 
* ​swizzle\_mode ([`TensorMapSwizzle`](/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): The swizzling mode for memory access optimization. Determines the granularity at which the last dimension is split or padded. ## Structs * [​`PipelineState`](./PipelineState): Manages state for a multi-stage pipeline with circular buffer semantics. * [​`RaggedTensorMap`](./RaggedTensorMap): Creates a TMA descriptor that can handle stores with varying lengths. This struct is mainly used for MHA, where sequence lengths may vary between samples. * [​`SharedMemBarrier`](./SharedMemBarrier): A hardware-accelerated synchronization primitive for GPU shared memory operations. * [​`TMATensorTile`](./TMATensorTile): A hardware-accelerated tensor memory access (TMA) tile for efficient asynchronous data movement. * [​`TMATensorTileArray`](./TMATensorTileArray): An array of TMA descriptors. ## Functions * [​`create_split_tma`](./create_split_tma): Creates a TMA tensor tile assuming that the first dimension in global memory has `UNKNOWN_VALUE`. * [​`create_tma_tile`](./create_tma_tile): Creates a `TMATensorTile` with specified tile dimensions and swizzle mode. * [​`create_tma_tile_template`](./create_tma_tile_template): Same as create\_tma\_tile, except the descriptor is only a placeholder or template for later replacement.
--- ## accumulate
--- ## dot_at_b
`dot_at_b(c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Computes the matrix product `c = transpose(a) @ b` (see the `dot_at_b_impl` overloads for the supported shapes and dtypes).
--- ## dot_at_b_impl
`dot_at_b_impl(c: LayoutTensor[DType.float32, Layout.row_major(16, 16), MutAnyOrigin], a: LayoutTensor[DType.float32, Layout.row_major(16, 16), ImmutAnyOrigin], b: LayoutTensor[DType.float32, Layout.row_major(16, 16), ImmutAnyOrigin])` `dot_at_b_impl(c: LayoutTensor[DType.float16, Layout.row_major(32, 32), MutAnyOrigin], a: LayoutTensor[DType.float16, Layout.row_major(32, 32), ImmutAnyOrigin], b: LayoutTensor[DType.float16, Layout.row_major(32, 32), ImmutAnyOrigin])`
--- ## extrx
`extrx(gpr: Int)` Extracts a row or moves it to x; the result is in amx0.
--- ## extry
`extry(gpr: Int)` Extracts a row or moves it to y; the result is in amx0.
--- ## fma
`fma[mode: StringSlice[StaticConstantOrigin], dtype: DType](z_row_index: Int, x_row_index: Int, y_row_index: Int, clear_z: Bool)`
--- ## fma16
`fma16(gpr: Int)` Float16 matrix multiply and add.
--- ## fma32
`fma32(gpr: Int)` Float32 matrix multiply and add.
--- ## fma64
`fma64(gpr: Int)` Float64 matrix multiply and add.
--- ## fms16
`fms16(gpr: Int)` Float16 matrix multiply and subtract.
--- ## fsm32
`fsm32(gpr: Int)` Float32 matrix multiply and subtract.
--- ## fsm64
`fsm64(gpr: Int)` Float64 matrix multiply and subtract.
--- ## genlut
`genlut(gpr: Int)`
--- ## apple_amx_intrinsics
## Functions * [​`dot_at_b`](./dot_at_b): * [​`dot_at_b_impl`](./dot_at_b_impl): * [​`extrx`](./extrx): Extracts a row or moves it to x; the result is in amx0. * [​`extry`](./extry): Extracts a row or moves it to y; the result is in amx0. * [​`fma`](./fma): * [​`fma16`](./fma16): Float16 matrix multiply and add. * [​`fma32`](./fma32): Float32 matrix multiply and add. * [​`fma64`](./fma64): Float64 matrix multiply and add. * [​`fms16`](./fms16): Float16 matrix multiply and subtract. * [​`fsm32`](./fsm32): Float32 matrix multiply and subtract. * [​`fsm64`](./fsm64): Float64 matrix multiply and subtract. * [​`genlut`](./genlut): * [​`ldx`](./ldx): * [​`ldy`](./ldy): * [​`ldz`](./ldz): * [​`ldzi`](./ldzi): * [​`load_z`](./load_z): * [​`mac16`](./mac16): SI16 matrix multiply and add. * [​`matfp`](./matfp): Float16 matrix multiply. * [​`max_int__`](./max_int__): UI16 matrix multiply. * [​`read_x`](./read_x): * [​`read_y`](./read_y): * [​`store_x`](./store_x): * [​`store_y`](./store_y): * [​`store_z`](./store_z): * [​`stx`](./stx): * [​`sty`](./sty): * [​`stz`](./stz): * [​`stzi`](./stzi): * [​`transpose_z_to_x_or_y`](./transpose_z_to_x_or_y): * [​`vec_int__`](./vec_int__): Horizontal ui16 multiply `z0[i] += x0[i] * y0[i]`. * [​`vecfp`](./vecfp): Horizontal float16 multiply `z0[i] += x0[i] * y0[i]`.
--- ## ldx
`ldx(gpr: Int)`
--- ## ldy
`ldy(gpr: Int)`
--- ## ldz
`ldz(gpr: Int)`
--- ## ldzi
`ldzi(gpr: Int)`
--- ## load_z
`load_z[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## mac16
`mac16(gpr: Int)` SI16 matrix multiply and add.
--- ## matfp
`matfp(gpr: Int)` Float16 matrix multiply.
--- ## max_int__
`max_int__(gpr: Int)` UI16 matrix multiply.
--- ## read_x
`read_x[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## read_y
`read_y[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## store_x
`store_x[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## store_y
`store_y[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## store_z
`store_z[row_count: Int, dtype: DType](src: LegacyUnsafePointer[Scalar[dtype]], start_index: Int)`
--- ## stx
`stx(gpr: Int)`
--- ## sty
`sty(gpr: Int)`
--- ## stz
`stz(gpr: Int)`
--- ## stzi
`stzi(gpr: Int)`
--- ## transpose_z_to_x_or_y
`transpose_z_to_x_or_y[destination: StringSlice[StaticConstantOrigin], dtype: DType](z_col_index: Int, xy_row_index: Int, z_row_suboffset: Int)`
--- ## vec_int__
`vec_int__(gpr: Int)` Horizontal ui16 multiply `z0[i] += x0[i] * y0[i]`.
--- ## vecfp
`vecfp(gpr: Int)` Horizontal float16 multiply `z0[i] += x0[i] * y0[i]`.
--- ## cpu
Provides CPU architecture-specific utility functions. ## Modules * [​`apple_amx_intrinsics`](./apple_amx_intrinsics/): * [​`neon_intrinsics`](./neon_intrinsics/): * [​`vnni_intrinsics`](./vnni_intrinsics/):
--- ## neon_intrinsics
--- ## dot_i16_to_i32_AVX2
`dot_i16_to_i32_AVX2[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the two words in each int32 element of a and b plus an int32 from src. **Constraints:** Requires AVX2. The size of the output vector must be 4, 8 or 16. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int16 SIMD vector. * ​b ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int16 SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i16_to_i32_x86
`dot_i16_to_i32_x86[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the two words in each int32 element of a and b plus an int32 from src using VNNI or AVX2. **Constraints:** Requires AVX512\_VNNI or AVX2. The size of the output vector must be 4, 8 or 16. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int16 SIMD vector. * ​b ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int16 SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i8_to_i32_AVX2
`dot_i8_to_i32_AVX2[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src. **Constraints:** Requires AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0, 255]. The b argument has range \[-128, 127]. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i8_to_i32_saturated_AVX2
`dot_i8_to_i32_saturated_AVX2[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src. **Constraints:** Requires AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0, 127], not \[0, 255]. The b argument has range \[-128, 127]. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i8_to_i32_saturated_x86
`dot_i8_to_i32_saturated_x86[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. **Constraints:** Requires AVX512\_VNNI or AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0, 127], not \[0, 255]. The b argument has range \[-128, 127]. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector of width elements.
--- ## dot_i8_to_i32_x86
`dot_i8_to_i32_x86[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. **Constraints:** Requires AVX512\_VNNI or AVX2. The size of the output vector must be 4, 8 or 16. The a argument has range \[0, 255]. The b argument has range \[-128, 127]. **Parameters:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the output SIMD vector. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for a. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for b. * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType for c. **Args:** * ​src ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int32 SIMD vector. * ​a ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): A uint8 SIMD vector. * ​b ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An int8 SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector of width elements.
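The per-lane arithmetic shared by the `dot_i8_to_i32_*` variants can be written as a scalar reference. This is an illustrative reconstruction of the semantics described above, not the actual implementation:

```mojo
# One int32 lane: result = src + sum over the four byte pairs packed into
# the lane, with a's bytes treated as unsigned and b's bytes as signed.
fn dot_i8_lane_reference(src: Int32, a_lane: UInt32, b_lane: UInt32) -> Int32:
    var acc = src
    for j in range(4):
        var shift = UInt32(8 * j)
        # Zero-extend a's byte; sign-extend b's byte.
        var a_byte = ((a_lane >> shift) & 0xFF).cast[DType.int32]()
        var b_byte = (
            ((b_lane >> shift) & 0xFF).cast[DType.int8]().cast[DType.int32]()
        )
        acc += a_byte * b_byte
    return acc
```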
--- ## vnni_intrinsics
## Functions * [​`dot_i16_to_i32_AVX2`](./dot_i16_to_i32_AVX2): The dot product of the two words in each int32 element of a and b plus an int32 from src. * [​`dot_i16_to_i32_x86`](./dot_i16_to_i32_x86): The dot product of the two words in each int32 element of a and b plus an int32 from src using VNNI or AVX2. * [​`dot_i8_to_i32_AVX2`](./dot_i8_to_i32_AVX2): The dot product of the four bytes in each int32 element of a and b plus an int32 from src. * [​`dot_i8_to_i32_saturated_AVX2`](./dot_i8_to_i32_saturated_AVX2): The dot product of the four bytes in each int32 element of a and b plus an int32 from src. * [​`dot_i8_to_i32_saturated_x86`](./dot_i8_to_i32_saturated_x86): The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. * [​`dot_i8_to_i32_x86`](./dot_i8_to_i32_x86): The dot product of the four bytes in each int32 element of a and b plus an int32 from src using VNNI or AVX2. * [​`pmaddubs`](./pmaddubs): * [​`pmaddw`](./pmaddw): * [​`vpdpbusd`](./vpdpbusd): * [​`vpdpbusds`](./vpdpbusds): * [​`vpdpwssd`](./vpdpwssd): * [​`vpdpwssds`](./vpdpwssds):
--- ## pmaddubs
`pmaddubs[width: Int](a: SIMD[DType.int32, width], b: SIMD[DType.int32, width]) -> SIMD[DType.int32, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## pmaddw
`pmaddw[width: Int](a: SIMD[DType.int32, width], b: SIMD[DType.int32, width]) -> SIMD[DType.int32, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
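`pmaddubs` and `pmaddw` carry no docstrings here, but they wrap the x86 `pmaddubsw` and `pmaddwd` instructions, whose per-lane semantics are well defined. A Python sketch, assuming (as the int32 SIMD argument types suggest) that the byte/word operands are packed into 32-bit lanes:

```python
def split_i16(lane):
    """Split one 32-bit lane into its (low, high) signed 16-bit halves."""
    to_s16 = lambda v: v - 0x10000 if v & 0x8000 else v
    return to_s16(lane & 0xFFFF), to_s16((lane >> 16) & 0xFFFF)

def pmaddw_lane(a_lane, b_lane):
    """pmaddwd semantics for one int32 lane: multiply the two int16
    pairs and add the two products (no intermediate saturation)."""
    a0, a1 = split_i16(a_lane)
    b0, b1 = split_i16(b_lane)
    return a0 * b0 + a1 * b1

def pmaddubs_pair(a_pair, b_pair):
    """pmaddubsw semantics for one int16 result: two u8*s8 products
    summed with signed 16-bit saturation."""
    s = a_pair[0] * b_pair[0] + a_pair[1] * b_pair[1]
    return max(-32768, min(32767, s))
```

The int16 saturation in `pmaddubsw` is why the non-saturated byte dot products above restrict `a` to \[0, 127].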
--- ## vpdpbusd
`vpdpbusd[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## vpdpbusds
`vpdpbusds[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## vpdpwssd
`vpdpwssd[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## vpdpwssds
`vpdpwssds[width: Int, a_type: DType, b_type: DType, c_type: DType](src: SIMD[c_type, width], a: SIMD[a_type, width], b: SIMD[b_type, width]) -> SIMD[c_type, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
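The `vpdpwssd`/`vpdpwssds` pair is the int16 analog of `vpdpbusd`/`vpdpbusds`: each int32 lane accumulates the dot product of the two signed words packed in that lane, and the `s`-suffixed form saturates the int32 accumulation. A per-lane sketch in Python (a semantic reference, not the Mojo API):

```python
def vpdpwssd_lane(src, a_words, b_words, saturate=False):
    """One int32 lane: src + a_words[0]*b_words[0] + a_words[1]*b_words[1],
    where each pair holds the two signed int16 words of one lane.
    With saturate=True this models vpdpwssds, which clamps to int32."""
    acc = src + a_words[0] * b_words[0] + a_words[1] * b_words[1]
    if saturate:
        acc = max(-2**31, min(2**31 - 1, acc))
    return acc
```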
--- ## arch
Provides architecture-specific utility functions. ## Packages * [​`cpu`](./cpu/): Provides CPU architecture-specific utility functions. * [​`sm100`](./sm100/): Provides NVIDIA Blackwell architecture-specific utility functions.
--- ## sm100
Provides NVIDIA Blackwell architecture-specific utility functions. ## Modules * [​`mma`](./mma/):
--- ## Major
`@register_passable(trivial)` `struct Major` ## Fields * ​val (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `K` `comptime K = Major(0)` ### `MN` `comptime MN = Major(1)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## MmaOpSM100_BlockScaled_SS
`@register_passable(trivial)` `struct MmaOpSM100_BlockScaled_SS[c_type: DType, a_type: DType, b_type: DType, sfa_dtype: DType, sfb_dtype: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], /, *, accum_type: DType = DType.float32, cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = False]` ## Fields * ​idesc (`UMMAInsDescriptor[MmaOpSM100_BlockScaled_SS._get_umma_kind[c_type, a_type, b_type, sfa_dtype, sfb_dtype, block_tile_shape, mma_shape, accum_type, cta_group, cluster_shape, a_swizzle, b_swizzle, transpose_b, a_type]()]`): * ​mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `mma` `mma(self, a: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_tmem: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, init_c: Bool)` MMA input tiles. The layout assumes that coalesce(A) has shape (bm, sw\_k, num\_sw\_k); we currently assume bm = mma\_m.
In the future, we can tile it to (mma\_m, sw\_k, num\_sw\_k, num\_mma\_m). The same logic applies to matrix B. ### `commit` `commit(self, ptr_mbar: LegacyUnsafePointer[type, address_space=AddressSpace.SHARED, mut=mut, origin=origin])` ### `wait` `wait(self)`
--- ## MmaOpSM100_SS
`@register_passable(trivial)` `struct MmaOpSM100_SS[c_type: DType, a_type: DType, b_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], /, *, accum_type: DType = DType.float32, cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = False]` ## Fields * ​idesc (`UMMAInsDescriptor[MmaOpSM100_SS._get_umma_kind[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type, cta_group, cluster_shape, a_swizzle, b_swizzle, transpose_b, a_type]()]`): * ​mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `mma` `mma(self, a: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_tmem: UInt32, init_c: Bool)` MMA input tiles. The layout assumes that coalesce(A) has shape (bm, sw\_k, num\_sw\_k); we currently assume bm = mma\_m. In the future, we can tile it to (mma\_m, sw\_k, num\_sw\_k, num\_mma\_m). The same logic applies to matrix B.
### `commit` `commit(self, ptr_mbar: LegacyUnsafePointer[type, address_space=AddressSpace.SHARED, mut=mut, origin=origin])` ### `wait` `wait(self)`
--- ## extract_first_2_modes
`extract_first_2_modes[l: Layout]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## mma
## Structs * [​`Major`](./Major): * [​`MmaOpSM100_BlockScaled_SS`](./MmaOpSM100_BlockScaled_SS): * [​`MmaOpSM100_SS`](./MmaOpSM100_SS): ## Functions * [​`extract_first_2_modes`](./extract_first_2_modes): * [​`max_contiguous_tile_shape`](./max_contiguous_tile_shape): Returns the maximum shape of a tile that's contiguous in memory for an MMA op. This is used to create a TMA descriptor. * [​`smem_descriptor`](./smem_descriptor):
--- ## max_contiguous_tile_shape
`max_contiguous_tile_shape[rank: Int, //, dtype: DType, tile_shape: IndexList[rank], /, *, major: Major = Major.K, swizzle_mode: SwizzleMode = SwizzleMode.NONE]() -> IntTuple` Returns the maximum shape of a tile that's contiguous in memory for an MMA op. This is used to create a TMA descriptor. **Returns:** `IntTuple`
--- ## smem_descriptor
`smem_descriptor[dtype: DType, //, *, BMN: Int, BK: Int, swizzle_mode: TensorMapSwizzle, is_k_major: Bool](ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin]) -> MMASmemDescriptorPair` **Returns:** `MMASmemDescriptorPair`
--- ## batched_matmul
`batched_matmul[rank: Int, a_type: DType, b_type: DType, c_type: DType, //, *, transpose_a: Bool, transpose_b: Bool, elementwise_epilogue_fn: OptionalReg[fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c_buf: NDBuffer[c_type, rank, origin, shape, strides], a_buf: NDBuffer[a_type, rank, origin, shape, strides], b_buf: NDBuffer[b_type, rank, origin, shape, strides], *, context: DeviceContextPtr = DeviceContextPtr())` `batched_matmul[rank: Int, a_type: DType, b_type: DType, c_type: DType, //, *, transpose_b: Bool, elementwise_epilogue_fn: OptionalReg[fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None] = None, saturated_vnni: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c_buf: NDBuffer[c_type, rank, origin, shape, strides], a_buf: NDBuffer[a_type, rank, origin, shape, strides], b_buf: NDBuffer[b_type, rank, origin, shape, strides], *, context: DeviceContextPtr = DeviceContextPtr())`
--- ## batched_matmul_dynamic_scaled_fp8
`batched_matmul_dynamic_scaled_fp8[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, input_scale_granularity: StringSlice[StaticConstantOrigin], weight_scale_granularity: StringSlice[StaticConstantOrigin], m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, transpose_b: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[c_type, 3, origin, shape], a: NDBuffer[a_type, 3, origin, shape], b: NDBuffer[b_type, 3, origin, shape], a_scales: NDBuffer[a_scales_type, 3, origin, shape], b_scales: NDBuffer[b_scales_type, 3, origin, shape], ctx: DeviceContext)`
--- ## batched_matmul_dynamic_scaled_fp8_naive
`batched_matmul_dynamic_scaled_fp8_naive[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, scales_granularity_mnk: IndexList[3], transpose_b: Bool = False](c_device: NDBuffer[c_type, 3, origin, shape], a_device: NDBuffer[a_type, 3, origin, shape], b_device: NDBuffer[b_type, 3, origin, shape], a_scales_device: NDBuffer[a_scales_type, 3, origin, shape], b_scales_device: NDBuffer[b_scales_type, 3, origin, shape], ctx: DeviceContext)`
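To make the scaled-FP8 contract concrete, the following pure-Python reference uses a deliberately simplified granularity (one scale per row of each operand, with the `transpose_b=True` layout where `b` is N x K); the real kernels support the finer block granularities described by `scales_granularity_mnk`:

```python
def scaled_matmul_ref(a, b, a_scales, b_scales):
    """C[i][j] = (sum_k a[i][k] * b[j][k]) * a_scales[i] * b_scales[j].

    a: M x K quantized values; b: N x K quantized values;
    a_scales / b_scales: one dequantization scale per row.
    At this granularity the scales factor out of the K reduction."""
    return [
        [
            sum(x * y for x, y in zip(row_a, row_b)) * sa * sb
            for row_b, sb in zip(b, b_scales)
        ]
        for row_a, sa in zip(a, a_scales)
    ]
```

With per-block (rather than per-row) scales, the reduction over K is instead split into segments and each partial sum is scaled before accumulation.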
--- ## batched_matmul_kernel_gpu
`batched_matmul_kernel_gpu[c_type: DType, a_type: DType, b_type: DType, layout_c: Layout, layout_a: Layout, layout_b: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None] = None](c_tensor: LayoutTensor[c_type, layout_c, MutAnyOrigin], a_tensor: LayoutTensor[a_type, layout_a, MutAnyOrigin], b_tensor: LayoutTensor[b_type, layout_b, MutAnyOrigin], m: Int, n: Int, k: Int)`
--- ## batched_matmul_shape
`batched_matmul_shape[rank: Int, a_type: DType, b_type: DType, single_thread_blocking_override: Bool](a_buff: NDBuffer[a_type, rank, origin], b_buff: NDBuffer[b_type, rank, origin]) -> IndexList[rank]` Compute the output shape of a `batch_matmul` operation, and assert the inputs are compatible. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Rank of the input and output tensors. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the lhs input tensor. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the rhs input tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​a\_buff ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The lhs input tensor. * ​b\_buff ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The rhs input tensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
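The shape rule above can be sketched in Python (non-transposed case; the transpose variants of `batched_matmul` change which inner dimensions must agree):

```python
def batched_matmul_shape(a_shape, b_shape):
    """Output shape of a batch_matmul for a: [..., M, K] and b: [..., K, N]:
    the matching leading batch dims, then [M, N]."""
    assert len(a_shape) == len(b_shape) and len(a_shape) >= 2
    assert a_shape[:-2] == b_shape[:-2], "batch dims must match"
    assert a_shape[-1] == b_shape[-2], "inner (K) dims must match"
    return a_shape[:-2] + (a_shape[-2], b_shape[-1])
```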
--- ## bmm_sm100_blockwise_scaled_fp8
`bmm_sm100_blockwise_scaled_fp8[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, *, transpose_b: Bool, umma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: OptionalReg[fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## get_shape_index_list
`get_shape_index_list[rank: Int, dtype: DType, layout: Layout](tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## bmm
## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None` ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Functions * [​`batched_matmul`](./batched_matmul): * [​`batched_matmul_dynamic_scaled_fp8`](./batched_matmul_dynamic_scaled_fp8): * [​`batched_matmul_dynamic_scaled_fp8_naive`](./batched_matmul_dynamic_scaled_fp8_naive): * [​`batched_matmul_kernel_gpu`](./batched_matmul_kernel_gpu): * [​`batched_matmul_shape`](./batched_matmul_shape): Compute the output shape of a `batch_matmul` operation, and assert the inputs are compatible. * [​`bmm_sm100_blockwise_scaled_fp8`](./bmm_sm100_blockwise_scaled_fp8): * [​`get_shape_index_list`](./get_shape_index_list): * [​`naive_batched_matmul_kernel`](./naive_batched_matmul_kernel):
--- ## naive_batched_matmul_kernel
`naive_batched_matmul_kernel[rank: Int, c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, elementwise_lambda_fn: OptionalReg[fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None] = None, accum_type: DType = get_accum_type[c_type]()](c_tensor: LayoutTensor[c_type, c_layout, MutAnyOrigin], a_tensor: LayoutTensor[a_type, a_layout, MutAnyOrigin], b_tensor: LayoutTensor[b_type, b_layout, MutAnyOrigin], c_buff_nd_shape: IndexList[rank])`
--- ## distributed_matmul
## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[input_index: Int, dtype: DType, rank: Int, width: Int, *, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None` ## Functions * [​`matmul_allreduce`](./matmul_allreduce): Performs C = matmul(A, B^T) followed by Out = allreduce(C) across multiple GPUs. Splits the A (or B) and C matrices into `num_partitions` submatrices along dimension `partition_dim`, so that `num_partitions` independent matmul + allreduce kernels can run with some of the computation overlapped.
--- ## matmul_allreduce
`matmul_allreduce[ngpus: Int, partition_dim: Int, outputs_lambda: elementwise_epilogue_type, a_dtype: DType, b_dtype: DType, out_dtype: DType, a_static_shape: DimList, b_static_shape: DimList, c_static_shape: DimList, out_static_shape: DimList, overlap_with_dpl: Bool = True](a_buffers: InlineArray[NDBuffer[a_dtype, 2, MutAnyOrigin, a_static_shape], ngpus], b_buffers: InlineArray[NDBuffer[b_dtype, 2, MutAnyOrigin, b_static_shape], ngpus], c_temp_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, c_static_shape], ngpus], output_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, out_static_shape], ngpus], rank_sigs: InlineArray[LegacyUnsafePointer[Signal], 8], ctxs: List[DeviceContext], num_partitions: ValOrDim[dim])` Performs C = matmul(A, B^T) followed by Out = allreduce(C) across multiple GPUs. Splits the A (or B) and C matrices into `num_partitions` submatrices along dimension `partition_dim`, so that `num_partitions` independent matmul + allreduce kernels can run with some of the computation overlapped. `matmul_allreduce[ngpus: Int, outputs_lambda: elementwise_epilogue_type, a_dtype: DType, b_dtype: DType, out_dtype: DType, a_static_shape: DimList, b_static_shape: DimList, c_static_shape: DimList, out_static_shape: DimList](a_buffers: InlineArray[NDBuffer[a_dtype, 2, MutAnyOrigin, a_static_shape], ngpus], b_buffers: InlineArray[NDBuffer[b_dtype, 2, MutAnyOrigin, b_static_shape], ngpus], c_temp_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, c_static_shape], ngpus], output_buffers: InlineArray[NDBuffer[out_dtype, 2, MutAnyOrigin, out_static_shape], ngpus], rank_sigs: InlineArray[LegacyUnsafePointer[Signal], 8], ctxs: List[DeviceContext])` Performs C = matmul(A, B^T) followed by Out = allreduce(C) across multiple GPUs. The implementation may split the A / B / C matrices and overlap computation to improve performance.
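The property behind the split is that matmul and allreduce both distribute over a row partition of A, so the partitions can be processed independently (and therefore overlapped). A pure-Python sketch of the row-partition case, checking that it reproduces the unpartitioned result (a semantic sketch, not the multi-GPU implementation):

```python
def matmul_bt(a, b):
    """C = A @ B^T for row-major nested lists (a: M x K, b: N x K)."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]

def allreduce(mats):
    """Elementwise sum of each GPU's partial result."""
    return [[sum(m[i][j] for m in mats) for j in range(len(mats[0][0]))]
            for i in range(len(mats[0]))]

def matmul_allreduce_partitioned(a_bufs, b_bufs, num_partitions):
    """Split A row-wise; each partition's matmul + allreduce is independent,
    so the partitions can be pipelined. Concatenating the per-partition
    results restores the unpartitioned output."""
    m = len(a_bufs[0])
    step = m // num_partitions  # assumes m divides evenly, for brevity
    out = []
    for p in range(num_partitions):
        rows = slice(p * step, (p + 1) * step)
        partials = [matmul_bt(a[rows], b) for a, b in zip(a_bufs, b_bufs)]
        out.extend(allreduce(partials))
    return out

# Two GPUs, identity weights: the partitioned path matches the direct path.
a_bufs = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
b_bufs = [[[1, 0], [0, 1]]] * 2
direct = allreduce([matmul_bt(a, b) for a, b in zip(a_bufs, b_bufs)])
assert matmul_allreduce_partitioned(a_bufs, b_bufs, 2) == direct
```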
--- ## config_in_smem
`config_in_smem[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, //, max_smem: Int](config: MatmulConfig[a_type, b_type, c_type, transpose_b]) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## dual_gemm
`dual_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, binary_lambda_fn: binary_fn_type = swilu, config: OptionalReg[MatmulConfig[a_type, b_type, c_type, transpose_b]] = None, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], ctx: DeviceContext)`
--- ## dual_gemv
`dual_gemv[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, binary_lambda_fn: binary_fn_type = swilu, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], ctx: DeviceContext)`
--- ## dual_gemv_kernel
`dual_gemv_kernel[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, simd_width: UInt, tile_m: UInt, tile_n: UInt, num_threads: UInt, binary_lambda_fn: binary_fn_type, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type]()](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape])`
--- ## dual_gemm (Dual_gemm)
## `comptime` values ### `binary_fn_type` `comptime binary_fn_type = fn[type: DType, width: Int](SIMD[type, width], SIMD[type, width]) -> SIMD[type, width]` ## Functions * [​`config_in_smem`](./config_in_smem): * [​`dual_gemm`](./dual_gemm): * [​`dual_gemv`](./dual_gemv): * [​`dual_gemv_kernel`](./dual_gemv_kernel): * [​`multistage_dual_gemm`](./multistage_dual_gemm): * [​`multistage_dual_gemm_kernel`](./multistage_dual_gemm_kernel): * [​`multistage_dual_mma`](./multistage_dual_mma): * [​`swilu`](./swilu): * [​`swishGLU`](./swishGLU): Reference: "GLU Variants Improve Transformer" by Noam Shazeer. The implementation follows CUTLASS, using one kernel invocation and writing to the destination once.
--- ## multistage_dual_gemm
`multistage_dual_gemm[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, //, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], binary_lambda_fn: binary_fn_type = swilu, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin], a: LayoutTensor[a_type, a_layout, origin], b0: LayoutTensor[b_type, b_layout, origin], b1: LayoutTensor[b_type, b_layout, origin], ctx: DeviceContext)` `multistage_dual_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], binary_lambda_fn: binary_fn_type = swilu, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, num_k_partitions: Int = 1](c: NDBuffer[c_type, 2, origin, c_shape], a: NDBuffer[a_type, 2, origin, a_shape], b0: NDBuffer[b_type, 2, origin, b_shape], b1: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)`
--- ## multistage_dual_gemm_kernel
`multistage_dual_gemm_kernel[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], binary_lambda_fn: binary_fn_type, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b0: LayoutTensor[b_type, b_layout, MutAnyOrigin], b1: LayoutTensor[b_type, b_layout, MutAnyOrigin])`
--- ## multistage_dual_mma
`multistage_dual_mma[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, a_smem_layout: Layout, b_type: DType, b_layout: Layout, b_smem_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_threads: Int, num_pipeline_stages: Int, transpose_b: Bool, /, *, swizzle_a: Bool = True, static_num_iters: Dim = Dim(), k_group_size: UInt = 1](c0: LayoutTensor[c_type, c_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c1: LayoutTensor[c_type, c_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_iter_arg: LayoutTensorIter[dtype, a_layout, MutAnyOrigin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], b0_iter_arg: LayoutTensorIter[b_type, b_layout, MutAnyOrigin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], b1_iter_arg: LayoutTensorIter[b_type, b_layout, MutAnyOrigin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], a_smem_iter_arg: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], mut b0_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], mut b1_smem_iter: LayoutTensorIter[b_type, 
b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], num_iters: Int, /, *, num_b_rows: OptionalReg[Int] = None)`
--- ## swilu
`swilu[dtype: DType, width: Int](x: SIMD[dtype, width], y: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## swishGLU
`swishGLU[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, target: StringSlice[StaticConstantOrigin] = "cpu"](a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b0: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], b1: NDBuffer[b_type, 2, MutAnyOrigin, b_shape], c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], ctx: DeviceContextPtr)` Reference: "GLU Variants Improve Transformer" by Noam Shazeer. The implementation follows CUTLASS, using one kernel invocation and writing to the destination once.
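Assuming `swilu(x, y) = silu(x) * y` with `silu(x) = x * sigmoid(x)` (the SwiGLU formulation from the referenced paper; the exact `swilu` definition is not documented here), the computation can be sketched in Python for a single row of `a`:

```python
import math

def silu(x):
    """SiLU / swish activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))

def swiglu_ref(a, b0, b1):
    """SwishGLU for one input row: silu(a @ b0^T) * (a @ b1^T), elementwise.

    a: K-vector; b0, b1: N x K weight rows (transpose_b layout)."""
    gate = [sum(x * w for x, w in zip(a, row)) for row in b0]
    up = [sum(x * w for x, w in zip(a, row)) for row in b1]
    return [silu(g) * u for g, u in zip(gate, up)]
```

The dual-GEMM structure exists because both projections read the same `a` tile, so one fused kernel can compute `a @ b0^T` and `a @ b1^T` together and apply the binary lambda before a single store.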
--- ## block_scaled_matmul
`block_scaled_matmul[c_type: DType, a_type: DType, b_type: DType, scales_dtype: DType, //, *, SF_VECTOR_SIZE: Int, transpose_b: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu", elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None](c_device: NDBuffer[c_type, 2, MutAnyOrigin, shape], a_device: NDBuffer[a_type, 2, MutAnyOrigin, shape], b_device: NDBuffer[b_type, 2, MutAnyOrigin, shape], a_scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], b_scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], tensor_sf: Float32, ctx: DeviceContext)`
--- ## block_scales_interleave
`block_scales_interleave[scales_dtype: DType, //, *, SF_VECTOR_SIZE: Int, target: StringSlice[StaticConstantOrigin] = "cpu"](output_scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], input_scales_device: NDBuffer[scales_dtype, 2, MutAnyOrigin, shape], ctx: DeviceContext)`
--- ## block_scales_interleave_fp4
`block_scales_interleave_fp4[scales_dtype: DType, input_scales_layout: Layout, output_scales_layout: Layout, //, *, SF_VECTOR_SIZE: Int = 16, num_max_threads: Int = 1024](ctx: DeviceContext, input_scales: LayoutTensor[scales_dtype, input_scales_layout, MutAnyOrigin], output_scales: LayoutTensor[scales_dtype, output_scales_layout, MutAnyOrigin])`
--- ## block_scales_interleave_fp4_kernel
`block_scales_interleave_fp4_kernel[scales_dtype: DType, input_scales_layout: Layout, output_scales_layout: Layout, *, SF_VECTOR_SIZE: Int = 16, num_max_threads: Int = 1024](input_scales: LayoutTensor[scales_dtype, input_scales_layout, MutAnyOrigin], output_scales: LayoutTensor[scales_dtype, output_scales_layout, MutAnyOrigin])`
--- ## fp4_quantization
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Functions * [​`block_scaled_matmul`](./block_scaled_matmul): * [​`block_scales_interleave`](./block_scales_interleave): * [​`block_scales_interleave_fp4`](./block_scales_interleave_fp4): * [​`block_scales_interleave_fp4_kernel`](./block_scales_interleave_fp4_kernel): * [​`naive_block_scaled_nvfp4_matmul`](./naive_block_scaled_nvfp4_matmul): * [​`naive_block_scaled_nvfp4_matmul_kernel`](./naive_block_scaled_nvfp4_matmul_kernel): * [​`quantize_dynamic_block_scaled`](./quantize_dynamic_block_scaled): * [​`quantize_dynamic_scaled_fp4`](./quantize_dynamic_scaled_fp4): * [​`quantize_dynamic_scaled_fp4_kernel`](./quantize_dynamic_scaled_fp4_kernel):
--- ## naive_block_scaled_nvfp4_matmul
`naive_block_scaled_nvfp4_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, SF_VECTOR_SIZE: Int, accum_type: DType = DType.float32, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, BLOCK_DIM: Int = 16](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## naive_block_scaled_nvfp4_matmul_kernel
`naive_block_scaled_nvfp4_matmul_kernel[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, accum_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, SF_VECTOR_SIZE: Int, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin])`
--- ## quantize_dynamic_block_scaled
`quantize_dynamic_block_scaled[out_dtype: DType, scales_dtype: DType, in_dtype: DType, //, *, SF_VECTOR_SIZE: Int, target: StringSlice[StaticConstantOrigin] = "cpu"](output_device: NDBuffer[out_dtype, 2, MutAnyOrigin, shape], scales_device: NDBuffer[scales_dtype, 5, MutAnyOrigin, shape], input_device: NDBuffer[in_dtype, 2, MutAnyOrigin, shape], tensor_sf: Float32, ctx: DeviceContext)`
--- ## quantize_dynamic_scaled_fp4
`quantize_dynamic_scaled_fp4[out_dtype: DType, scales_dtype: DType, in_dtype: DType, output_layout: Layout, scales_layout: Layout, input_layout: Layout, //, *, SF_VECTOR_SIZE: Int = 16, num_max_threads: Int = 512](ctx: DeviceContext, output: LayoutTensor[out_dtype, output_layout, MutAnyOrigin], scales: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], input: LayoutTensor[in_dtype, input_layout, MutAnyOrigin], num_cols: Int, num_cols_padded: Int, tensor_sf: Float32 = 1)`
--- ## quantize_dynamic_scaled_fp4_kernel
`quantize_dynamic_scaled_fp4_kernel[out_dtype: DType, scales_dtype: DType, in_dtype: DType, output_layout: Layout, scales_layout: Layout, input_layout: Layout, *, SF_VECTOR_SIZE: Int = 16, ELEMENTS_PER_THREAD: Int = 8, num_max_threads: Int = 512](output: LayoutTensor[out_dtype, output_layout, MutAnyOrigin], scales: LayoutTensor[scales_dtype, scales_layout, MutAnyOrigin], input: LayoutTensor[in_dtype, input_layout, MutAnyOrigin], num_cols: Int, num_cols_padded: Int, tensor_sf: Float32)`
--- ## cast_f4e2m1x2_to_fp16x2
`cast_f4e2m1x2_to_fp16x2(x: UInt8) -> SIMD[DType.float16, 2]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## cast_fp32_to_fp4e2m1
`cast_fp32_to_fp4e2m1[width: Int, //](x: SIMD[DType.float32, width]) -> UInt32` **Returns:** `UInt32`
--- ## cast_fp_to_fp4e2m1
`cast_fp_to_fp4e2m1[dtype: DType, width: Int, //](x: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## cast_uint_to_fp4e2m1
`cast_uint_to_fp4e2m1[in_dtype: DType, in_width: Int, //, *, out_dtype: DType, out_width: Int](x: SIMD[in_dtype, in_width]) -> SIMD[out_dtype, out_width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## convert_ref_scales_to_mxfp8_format
`convert_ref_scales_to_mxfp8_format[ref_scales_type: DType, scales_type: DType, *, REF_BLOCK_SIZE: Int, SF_VECTOR_SIZE: Int](m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim], ref_a_scales: NDBuffer[ref_scales_type, 2, origin, shape, strides], ref_b_scales: NDBuffer[ref_scales_type, 2, origin, shape, strides], a_scales: NDBuffer[scales_type, 5, origin, shape, strides], b_scales: NDBuffer[scales_type, 5, origin, shape, strides])`
--- ## fp4_utils
## `comptime` values ### `E2M1_TO_FLOAT32` `comptime E2M1_TO_FLOAT32 = SIMD[DType.float32, 16](0, 0.5, 1, 1.5, 2, 3, 4, 6, -0.0, -0.5, -1, -1.5, -2, -3, -4, -6)` ### `MXFP4_SF_DTYPE` `comptime MXFP4_SF_DTYPE = DType.float8_e8m0fnu` ### `MXFP4_SF_VECTOR_SIZE` `comptime MXFP4_SF_VECTOR_SIZE = 32` ### `MXFP8_SF_DTYPE` `comptime MXFP8_SF_DTYPE = DType.float8_e8m0fnu` ### `MXFP8_SF_VECTOR_SIZE` `comptime MXFP8_SF_VECTOR_SIZE = 32` ### `NVFP4_SF_DTYPE` `comptime NVFP4_SF_DTYPE = DType.float8_e4m3fn` ### `NVFP4_SF_VECTOR_SIZE` `comptime NVFP4_SF_VECTOR_SIZE = 16` ### `SF_ATOM_K` `comptime SF_ATOM_K = 4` ### `SF_ATOM_M` `comptime SF_ATOM_M = Tuple[Int, Int](32, 4)` ### `SF_MN_GROUP_SIZE` `comptime SF_MN_GROUP_SIZE = SF_ATOM_M[0] * SF_ATOM_M[1]` ## Functions * [​`cast_f4e2m1x2_to_fp16x2`](./cast_f4e2m1x2_to_fp16x2): * [​`cast_fp32_to_fp4e2m1`](./cast_fp32_to_fp4e2m1): * [​`cast_fp_to_fp4e2m1`](./cast_fp_to_fp4e2m1): * [​`cast_uint_to_fp4e2m1`](./cast_uint_to_fp4e2m1): * [​`convert_ref_scales_to_mxfp8_format`](./convert_ref_scales_to_mxfp8_format):
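The `E2M1_TO_FLOAT32` table above maps each 4-bit E2M1 code to its float value. As a hedged illustration (not the Mojo implementation; the low-nibble-first packing order is an assumption), decoding a byte that packs two FP4 E2M1 values reduces to two table lookups:

```python
# Illustrative FP4 E2M1 decode, mirroring the E2M1_TO_FLOAT32 table
# above. Not the Mojo kernel; the nibble order is an assumption.
E2M1_TO_FLOAT32 = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                   -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0]

def decode_f4e2m1x2(byte: int) -> tuple[float, float]:
    """Decode one byte holding two packed E2M1 codes."""
    lo = byte & 0x0F          # first 4-bit code (low nibble)
    hi = (byte >> 4) & 0x0F   # second 4-bit code (high nibble)
    return E2M1_TO_FLOAT32[lo], E2M1_TO_FLOAT32[hi]
```

For example, the byte `0x72` decodes to the codes `0x2` (1.0) and `0x7` (6.0).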
--- ## batched_quantize_dynamic_scaled_fp8
`batched_quantize_dynamic_scaled_fp8[out_dtype: DType, in_dtype: DType, scales_dtype: DType, //, group_size_or_per_token: Int](scaled_output: NDBuffer[out_dtype, 3, MutAnyOrigin], scales: NDBuffer[scales_dtype, 3, MutAnyOrigin], input: NDBuffer[in_dtype, 3, origin, shape, strides], scale_ub: Float32, ctx: DeviceContext)`
--- ## batched_quantize_fp8_kernel
`batched_quantize_fp8_kernel[out_type: DType, scales_type: DType, in_type: DType, warps_per_block: Int, group_size: Int](output: NDBuffer[out_type, 3, MutAnyOrigin], scales: NDBuffer[scales_type, 3, MutAnyOrigin], input: NDBuffer[in_type, 3, MutAnyOrigin], scale_ub: Scalar[scales_type])`
--- ## blockwise_scaled_fp8_with_epilogue
`blockwise_scaled_fp8_with_epilogue[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, scales_granularity_mnk: IndexList[3], BLOCK_DIM: Int = 16, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, accum_type: DType = get_accum_type[c_type]()](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)` Our sm100 blockwise scaled fp8 matmul kernel does not yet support fusion of elementwise operations. This is a temporary implementation that uses our sm100 blockwise scaled fp8 matmul kernel and dispatches a separate epilogue kernel to apply the elementwise operations. For non-B200 GPUs, we use the naive blockwise scaled fp8 matmul, which supports a normal epilogue natively.
--- ## convert_e4m3fn_to_e4m3fnuz
`convert_e4m3fn_to_e4m3fnuz(input_buffer: LayoutTensor[DType.float8_e4m3fn, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_buffer: LayoutTensor[DType.float8_e4m3fnuz, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContext)` Convert E4M3FN weights to E4M3FNUZ format for AMD GPU compatibility. This conversion handles the key difference between E4M3FN and E4M3FNUZ: the bit pattern 10000000 (-128) represents zero in E4M3FN but NaN in E4M3FNUZ. **Args:** * ​input\_buffer ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input tensor in E4M3FN format. * ​output\_buffer ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor to store E4M3FNUZ format. * ​context ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context for kernel execution.
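The single bit-pattern hazard described above can be sketched in isolation. This is an illustration of that one case only, not the Mojo kernel; a complete conversion must also account for the formats' other encoding differences.

```python
# Sketch of the bit-pattern hazard described above: 0b10000000 encodes
# zero in E4M3FN but NaN in E4M3FNUZ, so it must be remapped (here, to
# positive zero). Illustrative only; a full E4M3FN -> E4M3FNUZ
# conversion handles more than this one pattern.
def remap_e4m3fn_zero(byte: int) -> int:
    if byte == 0b10000000:   # zero in E4M3FN, NaN in E4M3FNUZ
        return 0b00000000    # map to +0.0
    return byte
```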
--- ## fp8_quantization
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Functions * [​`batched_quantize_dynamic_scaled_fp8`](./batched_quantize_dynamic_scaled_fp8): * [​`batched_quantize_fp8_kernel`](./batched_quantize_fp8_kernel): * [​`blockwise_scaled_fp8_with_epilogue`](./blockwise_scaled_fp8_with_epilogue): Our sm100 blockwise scaled fp8 matmul kernel does not yet support fusion of elementwise operations. This is a temporary implementation that uses our sm100 blockwise scaled fp8 matmul kernel and dispatches a separate epilogue kernel to apply the elementwise operations. For non-B200 GPUs, we use the naive blockwise scaled fp8 matmul, which supports a normal epilogue natively. * [​`convert_e4m3fn_to_e4m3fnuz`](./convert_e4m3fn_to_e4m3fnuz): Convert E4M3FN weights to E4M3FNUZ format for AMD GPU compatibility. * [​`matmul_dynamic_scaled_fp8`](./matmul_dynamic_scaled_fp8): * [​`naive_blockwise_scaled_fp8_grouped_matmul`](./naive_blockwise_scaled_fp8_grouped_matmul): * [​`naive_blockwise_scaled_fp8_grouped_matmul_kernel`](./naive_blockwise_scaled_fp8_grouped_matmul_kernel): * [​`naive_blockwise_scaled_fp8_matmul`](./naive_blockwise_scaled_fp8_matmul): * [​`naive_blockwise_scaled_fp8_matmul_kernel`](./naive_blockwise_scaled_fp8_matmul_kernel): * [​`quantize_dynamic_scaled_fp8`](./quantize_dynamic_scaled_fp8): * [​`quantize_fp8_kernel`](./quantize_fp8_kernel): * [​`quantize_static_scaled_fp8`](./quantize_static_scaled_fp8):
--- ## matmul_dynamic_scaled_fp8
`matmul_dynamic_scaled_fp8[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, input_scale_granularity: StringSlice[StaticConstantOrigin], weight_scale_granularity: StringSlice[StaticConstantOrigin], m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, transpose_b: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], a_scales: NDBuffer[a_scales_type, 2, origin, shape], b_scales: NDBuffer[b_scales_type, 2, origin, shape], ctx: DeviceContext)`
--- ## naive_blockwise_scaled_fp8_grouped_matmul
`naive_blockwise_scaled_fp8_grouped_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, //, BLOCK_DIM_N: Int = 32, BLOCK_DIM_M: Int = 16, transpose_b: Bool = True, scales_granularity_mnk: OptionalReg[IndexList[3]] = None, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## naive_blockwise_scaled_fp8_grouped_matmul_kernel
`naive_blockwise_scaled_fp8_grouped_matmul_kernel[c_layout: Layout, a_layout: Layout, b_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, accum_type: DType, transpose_b: Bool = True, scales_granularity_mnk: OptionalReg[IndexList[3]] = None, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin])`
--- ## naive_blockwise_scaled_fp8_matmul
`naive_blockwise_scaled_fp8_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, //, *, BLOCK_DIM: Int = 16, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, accum_type: DType = get_accum_type[c_type](), scales_granularity_mnk: OptionalReg[IndexList[3]] = None](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)` `naive_blockwise_scaled_fp8_matmul[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, c_shape: DimList, a_shape: DimList, b_shape: DimList, a_scale_shape: DimList, b_scale_shape: DimList, //, *, BLOCK_DIM: Int = 16, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, accum_type: DType = get_accum_type[c_type](), scales_granularity_mnk: OptionalReg[IndexList[3]] = None](c_device: NDBuffer[c_type, 2, origin, c_shape], a_device: NDBuffer[a_type, 2, origin, a_shape], b_device: NDBuffer[b_type, 2, origin, b_shape], a_scales_device: NDBuffer[a_scales_type, 2, origin, a_scale_shape], b_scales_device: NDBuffer[b_scales_type, 2, origin, b_scale_shape], ctx: DeviceContext)`
--- ## naive_blockwise_scaled_fp8_matmul_kernel
`naive_blockwise_scaled_fp8_matmul_kernel[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, accum_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scale_layout: Layout, b_scale_layout: Layout, BLOCK_DIM: Int, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, scales_granularity_mnk: OptionalReg[IndexList[3]] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scale_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scale_layout, MutAnyOrigin])`
--- ## quantize_dynamic_scaled_fp8
`quantize_dynamic_scaled_fp8[out_dtype: DType, in_dtype: DType, scales_dtype: DType, input_shape: DimList, //, group_size_or_per_token: Int](scaled_output: NDBuffer[out_dtype, 2, MutAnyOrigin], scales: NDBuffer[scales_dtype, 2, MutAnyOrigin], input: NDBuffer[in_dtype, 2, origin, input_shape], scale_ub: Float32, ctx: DeviceContext)`
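Dynamic FP8 quantization derives a scale from the data itself. A hedged sketch of the common per-token policy suggested by the signature above (scale from each row's max magnitude relative to the FP8 max, optionally capped by `scale_ub`); the Mojo kernel's exact grouping and rounding behavior is not shown:

```python
# Hedged sketch of per-token dynamic FP8 scaling. The scale is the
# row's max magnitude divided by the largest finite float8_e4m3fn
# value (448), optionally capped by scale_ub; quantized values are
# clamped to the representable range. Illustrative only.
FP8_E4M3_MAX = 448.0  # max finite value of float8_e4m3fn

def quantize_row(row, scale_ub=None):
    amax = max(abs(v) for v in row) or 1.0        # avoid div-by-zero
    scale = amax / FP8_E4M3_MAX
    if scale_ub is not None:
        scale = min(scale, scale_ub)              # cap the scale
    q = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in row]
    return q, scale
```

Dequantization multiplies the stored values by the returned scale.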
--- ## quantize_fp8_kernel
`quantize_fp8_kernel[out_type: DType, scales_type: DType, in_type: DType, warps_per_block: Int, group_size: Int](output: NDBuffer[out_type, 2, MutAnyOrigin], scales: NDBuffer[scales_type, 2, MutAnyOrigin], input: NDBuffer[in_type, 2, MutAnyOrigin], scale_ub: Scalar[scales_type])`
--- ## quantize_static_scaled_fp8
`quantize_static_scaled_fp8[out_dtype: DType, in_dtype: DType, scale_is_inverted: Bool = True](out_buffer: NDBuffer[out_dtype, 2, origin, shape, strides], in_buffer: NDBuffer[in_dtype, 2, origin, shape, strides], scale: Float32, context: DeviceContext)`
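Static scaled quantization applies one precomputed scale to the whole buffer. A hedged sketch of the behavior suggested by the signature above; the reading of `scale_is_inverted` (that the passed scale is `1/s`, so quantization multiplies rather than divides) is an assumption, and the clamp stands in for the actual FP8 cast:

```python
# Illustrative per-element static FP8 scaling. scale_is_inverted is
# assumed to mean the caller passes 1/s; results are clamped to the
# float8_e4m3fn finite range. Not the Mojo kernel's exact semantics.
FP8_E4M3_MAX = 448.0

def quantize_static(x: float, scale: float, scale_is_inverted: bool = True) -> float:
    v = x * scale if scale_is_inverted else x / scale
    return max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v))
```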
--- ## GEMVAlgorithm
`struct GEMVAlgorithm` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `GEMV_KERNEL` `comptime GEMV_KERNEL = GEMVAlgorithm(0)` ### `GEMV_KERNEL_VECTOR` `comptime GEMV_KERNEL_VECTOR = GEMVAlgorithm(1)` ### `GEMV_SPLIT_K` `comptime GEMV_SPLIT_K = GEMVAlgorithm(2)` ### `GEVM_KERNEL` `comptime GEVM_KERNEL = GEMVAlgorithm(4)` ### `GEVM_KERNEL_VECTOR` `comptime GEVM_KERNEL_VECTOR = GEMVAlgorithm(3)` ### `MATMUL_NAIVE` `comptime MATMUL_NAIVE = GEMVAlgorithm(5)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__is__` `__is__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__isnot__` `__isnot__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` Returns the string representation of this algorithm. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String: A human-readable string representation of the algorithm. ### `write_to` `write_to(self, mut writer: T)`
--- ## gemv
`gemv[parallelize: Bool, c_size: Dim, c_type: DType, a_shape: DimList, a_type: DType, b_size: Dim, b_type: DType, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c_buf: NDBuffer[c_type, 1, origin, c_size], a_buf: NDBuffer[a_type, 2, origin, a_shape], b_buf: NDBuffer[b_type, 1, origin, b_size])`
--- ## gemv_gpu
`gemv_gpu[transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], ctx: DeviceContext)`
--- ## gemv_gpu_dispatch
`gemv_gpu_dispatch[transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, pdl_level: PDLLevel = PDLLevel()](kernel_func: GEMVAlgorithm, c: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], ctx: DeviceContext)`
--- ## gemv_kernel
`gemv_kernel[c_type: DType, a_type: DType, b_type: DType, *, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type](), pdl_level: PDLLevel = PDLLevel()](c: LegacyUnsafePointer[Scalar[c_type]], a: LegacyUnsafePointer[Scalar[a_type]], b: LegacyUnsafePointer[Scalar[b_type]], m: Int, n: Int, k: Int)`
--- ## gemv_kernel_vector
`gemv_kernel_vector[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, *, simd_width: UInt, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type](), pdl_level: PDLLevel = PDLLevel()](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], m: Int, n: Int, k: Int)`
--- ## gemv_split_k
`gemv_split_k[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, simd_width: UInt, tile_m: UInt, tile_n: UInt, num_threads: UInt, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type](), check_bounds: Bool = True, pdl_level: PDLLevel = PDLLevel()](output: LayoutTensor[c_type, c_layout, MutAnyOrigin], act: LayoutTensor[a_type, a_layout, MutAnyOrigin], weight: LayoutTensor[b_type, b_layout, MutAnyOrigin], m: Int, n: Int, k: Int)` GEMV with tiling in the K dimension. Assuming the B (weight) matrix is transposed, i.e. row-major N x K, this kernel implements a vector (1 x K) times a matrix (N x K) product. The implementation can handle M > 1, but it is only optimal for tiny M, so we use it for M = 1 only.
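The K-tiling described above can be sketched structurally: each K-tile produces an independent partial dot product, and the partials are reduced at the end. This shows only the tiling idea, not the GPU kernel's threading, vectorization, or memory layout:

```python
# Structural sketch of split-K GEMV with a transposed (row-major N x K)
# weight matrix: partition K into tiles, compute one partial dot
# product per tile, then reduce the partials. Illustrative only.
def gemv_split_k(weight, act, tile_k):
    """weight: N x K (list of rows), act: length-K vector."""
    k = len(act)
    out = []
    for row in weight:
        partials = [
            sum(row[i] * act[i] for i in range(t, min(t + tile_k, k)))
            for t in range(0, k, tile_k)   # one partial per K-tile
        ]
        out.append(sum(partials))          # final cross-tile reduction
    return out
```

On a GPU, the per-tile partials are computed by separate thread groups, which is what recovers parallelism when M is tiny and K is large.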
--- ## gevm_kernel
`gevm_kernel[c_type: DType, a_type: DType, b_type: DType, *, tile_size: Int, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type](), pdl_level: PDLLevel = PDLLevel()](c: LegacyUnsafePointer[Scalar[c_type]], a: LegacyUnsafePointer[Scalar[a_type]], b: LegacyUnsafePointer[Scalar[b_type]], m: Int, n: Int, k: Int)`
--- ## gemv (Gemv)
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Structs * [​`GEMVAlgorithm`](./GEMVAlgorithm): ## Functions * [​`gemv`](./gemv): * [​`gemv_gpu`](./gemv_gpu): * [​`gemv_gpu_dispatch`](./gemv_gpu_dispatch): * [​`gemv_kernel`](./gemv_kernel): * [​`gemv_kernel_vector`](./gemv_kernel_vector): * [​`gemv_split_k`](./gemv_split_k): GEMV with tiling in the K dimension. Assuming the B (weight) matrix is transposed, i.e. row-major N x K, this kernel implements a vector (1 x K) times a matrix (N x K) product. The implementation can handle M > 1, but it is only optimal for tiny M, so we use it for M = 1 only. * [​`gevm_kernel`](./gevm_kernel): * [​`log_shape`](./log_shape): * [​`naive_gemv`](./naive_gemv): * [​`reverse_idx`](./reverse_idx):
--- ## log_shape
`log_shape[has_mode_1: Bool, has_mode_2: Bool, name: String](mode_1: Int, mode_2: Int)`
--- ## naive_gemv
`naive_gemv[c_size: Dim, a_shape: DimList, b_size: Dim, dtype: DType](c_buf: NDBuffer[dtype, 1, origin, c_size], a_buf: NDBuffer[dtype, 2, origin, a_shape], b_buf: NDBuffer[dtype, 1, origin, b_size])`
--- ## reverse_idx
`reverse_idx[transpose: Bool](x: Int, y: Int) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## dispatch_amd_matmul_by_block_shape
`dispatch_amd_matmul_by_block_shape[c_type: DType, a_type: DType, b_type: DType, transpose_b: Bool, N: Int, K: Int, launcher_fn: fn[config: MatmulConfig[a_type, b_type, c_type, transpose_b]]() raises capturing -> None, default_block_tile_shape: IndexList[3], use_heuristic: Bool = False](M: Int, ctx: DeviceContext)` Dispatches to the best kernel configuration based on the runtime M dimension.
--- ## grouped_matmul
`grouped_matmul[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
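A hedged sketch of the grouped (MoE-style) matmul pattern suggested by the signature above: `a_offsets` is read as delimiting each active group's contiguous token rows in `a`, and `expert_ids` as selecting which expert weight matrix in the stacked `b` each group uses. This interpretation of the parameters is an assumption, not the kernel's documented contract:

```python
# Illustrative grouped-matmul semantics (assumed, not the Mojo
# kernel): group g covers rows a_offsets[g]:a_offsets[g+1] of `a`
# and multiplies them by expert weights b[expert_ids[g]], laid out
# N x K (transpose_b).
def grouped_matmul(a, b, a_offsets, expert_ids):
    out = []
    for g, eid in enumerate(expert_ids):
        start, end = a_offsets[g], a_offsets[g + 1]
        w = b[eid]                                    # N x K weights
        for r in range(start, end):                   # tokens of group g
            out.append([sum(a[r][k] * w[n][k] for k in range(len(a[r])))
                        for n in range(len(w))])
    return out
```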
--- ## grouped_matmul_amd
`grouped_matmul_amd[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool = True, block_tile_shape: IndexList[3] = Index(128, 128, 64), elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_amd_kernel_launcher
`grouped_matmul_amd_kernel_launcher[c_type: DType, a_type: DType, b_type: DType, layout_c: Layout, layout_a: Layout, layout_b: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c_tensor: LayoutTensor[c_type, layout_c, MutAnyOrigin], a_tensor: LayoutTensor[a_type, layout_a, MutAnyOrigin], b_tensor: LayoutTensor[b_type, layout_b, MutAnyOrigin], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int)`
--- ## grouped_matmul_kernel_sm100
`grouped_matmul_kernel_sm100[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, c_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], a_desc_layout: Layout, b_desc_layout: Layout, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, transpose_b: Bool = True, num_threads: Int = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], num_iters: Int)`
--- ## grouped_matmul_sm100
`grouped_matmul_sm100[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool = True, mma_shape: IndexList[3] = Index(64, 128, 16), block_tile_shape: IndexList[3] = Index(64, 128, 64), elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_vendor
`grouped_matmul_vendor[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool = True, use_tf32: Bool = False](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul (Grouped_matmul)
## Functions * [​`dispatch_amd_matmul_by_block_shape`](./dispatch_amd_matmul_by_block_shape): Dispatches to the best kernel configuration based on runtime M dimension. * [​`grouped_matmul`](./grouped_matmul): * [​`grouped_matmul_amd`](./grouped_matmul_amd): * [​`grouped_matmul_amd_kernel_launcher`](./grouped_matmul_amd_kernel_launcher): * [​`grouped_matmul_kernel_sm100`](./grouped_matmul_kernel_sm100): * [​`grouped_matmul_sm100`](./grouped_matmul_sm100): * [​`grouped_matmul_vendor`](./grouped_matmul_vendor): * [​`naive_epilogue`](./naive_epilogue): * [​`naive_epilogue_kernel`](./naive_epilogue_kernel): * [​`naive_grouped_matmul`](./naive_grouped_matmul): * [​`naive_grouped_matmul_kernel`](./naive_grouped_matmul_kernel):
--- ## naive_epilogue
`naive_epilogue[c_type: DType, c_shape: DimList, *, elementwise_lambda_fn: elementwise_epilogue_type](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], ctx: DeviceContext)`
--- ## naive_epilogue_kernel
`naive_epilogue_kernel[c_type: DType, c_shape: DimList, *, elementwise_lambda_fn: elementwise_epilogue_type](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape])`
--- ## naive_grouped_matmul
`naive_grouped_matmul[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## naive_grouped_matmul_kernel
`naive_grouped_matmul_kernel[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin])`
--- ## WarpRole
`@register_passable(trivial)` `struct WarpRole` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Epilogue` `comptime Epilogue = WarpRole(3)` ### `MainLoad` `comptime MainLoad = WarpRole(4)` ### `Mma` `comptime Mma = WarpRole(5)` ## Methods ### `__eq__` `__eq__(self, other: UInt) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ge__` `__ge__(self, other: UInt) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_main_load` `static is_main_load() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_mma` `static is_mma() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_epilogue` `static is_epilogue() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## blackwell_tma_umma_warp_specialized_kernel
`blackwell_tma_umma_warp_specialized_kernel[a_type: DType, b_type: DType, c_type: DType, expert_m: Int, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_tensor_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: UInt, num_accum_pipeline_stages: UInt, num_output_stages: UInt = 2, output_tile_shape: IndexList[2] = Index(128, 32), transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 2, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, transpose_c: Bool = False](num_active_experts: Int, a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], b_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], mnk: StaticTuple[UInt32, 3])`
--- ## consumer_main_loop
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, a_smem_layout: Layout, b_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1)](tmem_addr: UInt32, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], mma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], consumer_phase: PipelineState[pipeline_stages], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32)`
--- ## grouped_matmul_sm100_persistent
`grouped_matmul_sm100_persistent[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cta_group: Int = 1, num_pipeline_stages: Optional[UInt] = None, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100
## Structs * [​`WarpRole`](./WarpRole): ## Functions * [​`blackwell_tma_umma_warp_specialized_kernel`](./blackwell_tma_umma_warp_specialized_kernel): * [​`consumer_main_loop`](./consumer_main_loop): * [​`grouped_matmul_sm100_persistent`](./grouped_matmul_sm100_persistent): * [​`load_AB`](./load_AB): * [​`multi_stage_store_C`](./multi_stage_store_C): * [​`stsm_helper`](./stsm_helper): * [​`zero_output`](./zero_output):
--- ## load_AB
`load_AB[a_type: DType, b_type: DType, a_layout: Layout, b_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, num_pipeline_stages: UInt, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], mma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tma_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_phase: PipelineState[Int(num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool, scheduler: TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB])`
--- ## multi_stage_store_C
`multi_stage_store_C[c_type: DType, c_smem_layout: Layout, c_layout: Layout, c_tensor_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: UInt, /, *, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], stage_stride_cols: UInt, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 1, num_output_warps: UInt = 4, max_tmem_cols: UInt = 512, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, transpose_c: Bool = False](c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], accum_pipeline_consumer_state: PipelineState[Int(num_accum_pipeline_stages)], accum_full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], accum_empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tmem_addr: UInt32, work_tile_coord: Tuple[UInt, UInt], group_end_idx: UInt32, elect_one_warp: Bool, M: UInt32, N: UInt32)`
--- ## stsm_helper
`stsm_helper[swizzle: Swizzle, transpose_c: Bool = False](vec: SIMD[dtype, size], dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## zero_output
`zero_output[c_type: DType, c_layout: Layout, *, output_tile_shape: IndexList[2]](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], coord: Tuple[UInt32, UInt32], group_end_idx: UInt32)`
--- ## blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel
`blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_tensor_layout: Layout, a_scales_tile_layout: Layout, a_scales_type: DType, a_offsets_layout: Layout, b_scales_type: DType, b_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, a_scales_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], num_pipeline_stages: UInt, cluster_shape: StaticTuple[Int32, 3], expert_n: Int, expert_ids_layout: Layout, b_scales_n: Int](num_active_experts: Int, a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], num_iters: UInt, b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], problem_shape: StaticTuple[Int32, 3])`
--- ## grouped_matmul_dynamic_scaled_fp8
`grouped_matmul_dynamic_scaled_fp8[c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, //, input_scale_granularity: StringSlice[StaticConstantOrigin], weight_scale_granularity: StringSlice[StaticConstantOrigin], m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, transpose_b: Bool = False, tokens_padded_per_expert: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[c_type, 2, MutAnyOrigin, shape], a: NDBuffer[a_type, 2, MutAnyOrigin, shape], b: NDBuffer[b_type, 3, MutAnyOrigin, shape], a_scales: NDBuffer[a_scales_type, 2, MutAnyOrigin, shape], b_scales: NDBuffer[b_scales_type, 3, MutAnyOrigin, shape], a_offsets: NDBuffer[a_offsets_type, 1, MutAnyOrigin, shape], expert_ids: NDBuffer[expert_ids_type, 1, MutAnyOrigin, shape], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100_blockwise_scaled_fp8
`grouped_matmul_sm100_blockwise_scaled_fp8[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, transpose_b: Bool, //, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, origin, address_space=address_space, element_layout=element_layout, 
layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100_blockwise_scaled_fp8_persistent
`grouped_matmul_sm100_blockwise_scaled_fp8_persistent[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, a_offsets_type: DType, expert_ids_type: DType, transpose_b: Bool, //, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_offsets: LayoutTensor[a_offsets_type, a_offsets_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[expert_ids_type, expert_ids_layout, origin, address_space=address_space, 
element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)`
--- ## grouped_matmul_sm100_blockwise_fp8
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Functions * [​`blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel`](./blackwell_gmm_tma_umma_warp_specialized_blockwise_fp8_kernel): * [​`grouped_matmul_dynamic_scaled_fp8`](./grouped_matmul_dynamic_scaled_fp8): * [​`grouped_matmul_sm100_blockwise_scaled_fp8`](./grouped_matmul_sm100_blockwise_scaled_fp8): * [​`grouped_matmul_sm100_blockwise_scaled_fp8_persistent`](./grouped_matmul_sm100_blockwise_scaled_fp8_persistent): * [​`load_AB`](./load_AB): * [​`matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel`](./matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel): * [​`multi_stage_reg_epilogue`](./multi_stage_reg_epilogue): * [​`promote_accumulators`](./promote_accumulators):
--- ## load_AB (grouped_matmul_sm100_blockwise_fp8)
`load_AB[a_type: DType, b_type: DType, a_scales_type: DType, a_layout: Layout, b_layout: Layout, a_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, a_scales_smem_layout: Layout, num_pipeline_stages: UInt, expert_ids_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_layout, a_scales_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], a_scales_smem: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[Int(num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt, elect_one_cta: Bool, scheduler: TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin])`
--- ## matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel
`matmul_sm100_grouped_blockwise_scaled_fp8_1d2d_kernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, accum_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_offsets_layout: Layout, expert_ids_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: UInt = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], a_offsets: LayoutTensor[DType.uint32, a_offsets_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a_scales: LayoutTensor[a_scales_type, a_scales_layout, MutAnyOrigin], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], num_iters: UInt)`
--- ## multi_stage_reg_epilogue
`multi_stage_reg_epilogue[c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, accum_type: DType, accum_layout: Layout, c_tensor_layout: Layout, /, *, c_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_warps: UInt, c_swizzle: TensorMapSwizzle](c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c: LayoutTensor[c_type, c_tensor_layout, MutAnyOrigin], c_coord: Tuple[UInt, UInt], elect_one_warp: Bool, group_end_idx: UInt32)`
--- ## promote_accumulators
`promote_accumulators[pipeline_stages: UInt, num_accum_pipeline_stages: UInt, accum_type: DType, accum_layout: Layout, a_scales_type: DType, b_scales_type: DType, b_scales_layout: Layout, a_scales_smem_layout: Layout, expert_ids_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, CLUSTER_SIZE: Int32, is_lower_frag_required: Bool, num_output_warps: UInt](b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], b_scales_n: Int, a_scales_smem_iter: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_output_pipeline: ProducerConsumerPipeline[Int(num_accum_pipeline_stages)], tmem_addr: UInt32, load_mma_pipeline: ProducerConsumerPipeline[Int(pipeline_stages)], work_tile_coord: Tuple[UInt, UInt], elect_one_warp: Bool, stage_stride_cols: UInt, k_iter: UInt, problem_shape: StaticTuple[Int32, 3], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], scheduler: TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB])`
--- ## RasterOrder
`@register_passable(trivial)` `struct RasterOrder` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AlongM` `comptime AlongM = RasterOrder(1)` ### `AlongN` `comptime AlongN = RasterOrder(0)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## TileScheduler
`@register_passable(trivial)` `struct TileScheduler[offsets_layout: Layout, //, *, static_MN: Int, tile_shape: IndexList[3], cluster: IndexList[3] = Index(1, 1, 1), cta_group: Int = 1, swizzle: Bool = False, swapAB: Bool = True]` ## Fields * ​num\_active\_experts (`Int`): * ​group\_offsets (`LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin]`): * ​current\_iter (`Int32`): * ​current\_group\_idx (`UInt32`): * ​current\_dynamic\_dim\_cumsum (`UInt32`): * ​block\_idx\_start (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `cta_group_tile_shape` `comptime cta_group_tile_shape = Index((tile_shape.__getitem__[3, DType.int64, Int](0) * cta_group), (tile_shape.__getitem__[3, DType.int64, Int](1) * cta_group))` ### `div_dynamic_block` `comptime div_dynamic_block = FastDiv[DType.uint32](TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB].cta_group_tile_shape.__getitem__[2, DType.int64, Int](TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, cta_group=cta_group, swizzle=swizzle, swapAB=swapAB].dynamic_dim))` ### `dynamic_dim` `comptime dynamic_dim = 1 if swapAB else 0` ### `kNum1DBlocksPerGroup` `comptime kNum1DBlocksPerGroup = 16` ### `num_static_dim_blocks` `comptime num_static_dim_blocks = SIMD[DType.uint32, 1](ceildiv(static_MN, tile_shape.__getitem__[3, DType.int64, Int](TileScheduler[static_MN=static_MN, tile_shape=tile_shape, cluster=cluster, 
cta_group=cta_group, swizzle=swizzle, swapAB=swapAB].static_dim)))` ### `static_dim` `comptime static_dim = 0 if swapAB else 1` ## Methods ### `__init__` `__init__(num_active_experts: Int, group_offsets: LayoutTensor[DType.uint32, offsets_layout, MutAnyOrigin]) -> Self` ### `fetch_next_work` `fetch_next_work(mut self) -> WorkInfo` **Returns:** `WorkInfo`
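The `group_offsets` field holds prefix sums of tokens per expert group, so one output dimension is dynamic per group while the other is static (`static_MN`). To see how such a scheduler can enumerate output tiles, here is a hedged Python sketch; the tile sizes, iteration order, and all names are illustrative assumptions, not the kernel's actual algorithm:

```python
# Illustrative enumeration of grouped-matmul work tiles: N is static,
# while each expert group contributes a dynamic number of rows given
# by prefix-sum group_offsets (length num_groups + 1, starting at 0).

def ceildiv(a: int, b: int) -> int:
    """Integer ceiling division."""
    return -(-a // b)

def work_tiles(group_offsets, static_n, tile_m, tile_n):
    n_tiles = ceildiv(static_n, tile_n)
    for g in range(len(group_offsets) - 1):
        rows = group_offsets[g + 1] - group_offsets[g]   # dynamic dim for group g
        for m in range(ceildiv(rows, tile_m)):
            for n in range(n_tiles):
                yield (g, m, n)   # (group, m-tile, n-tile) names one output tile

# Two groups with 96 and 64 rows, a static N of 256, 64x128 tiles.
tiles = list(work_tiles([0, 96, 160], static_n=256, tile_m=64, tile_n=128))
```

A persistent scheduler would additionally interleave these tiles across CTAs (e.g. striding by grid size from `block_idx_start`) rather than yielding them to a single consumer.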
--- ## WorkInfo
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​is\_valid\_tile (`Bool`): * ​terminate (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_done` `is_done(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
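A typical persistent-kernel loop drains the scheduler until the returned work item reports termination. The Python sketch below mirrors the `WorkInfo` fields (`m`, `n`, `is_valid_tile`, `terminate`) with plain tuples; the loop structure and the sample queue are assumptions for illustration only.

```python
# Hypothetical consumption loop for work items shaped like WorkInfo:
# (m, n, is_valid_tile, terminate). Names and queue contents are
# illustrative, not taken from the kernel.

def drain(queue):
    """Process work items until a terminating item is seen."""
    processed = []
    for m, n, is_valid_tile, terminate in queue:
        if terminate:          # is_done(): scheduler has no more work
            break
        if is_valid_tile:      # is_valid(): a real output tile to compute
            processed.append((m, n))
        # invalid, non-terminating items are skipped (e.g. padding slots)
    return processed

work = [(0, 0, True, False), (0, 1, False, False),
        (1, 0, True, False), (0, 0, False, True)]
valid_tiles = drain(work)
```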
--- ## grouped_matmul_tile_scheduler
## Structs * [​`RasterOrder`](./RasterOrder): * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo):
--- ## linalg
Provides CPU and GPU implementations of linear algebra functions. ## Packages * [​`arch`](./arch/): Provides architecture-specific utility functions. * [​`matmul`](./matmul/): Provides the backend implementation for matmuls. ## Modules * [​`accumulate`](./accumulate/): * [​`bmm`](./bmm/): * [​`distributed_matmul`](./distributed_matmul/): * [​`dual_gemm`](./dual_gemm/): * [​`fp4_quantization`](./fp4_quantization/): * [​`fp4_utils`](./fp4_utils/): * [​`fp8_quantization`](./fp8_quantization/): * [​`gemv`](./gemv/): * [​`grouped_matmul`](./grouped_matmul/): * [​`grouped_matmul_sm100`](./grouped_matmul_sm100/): * [​`grouped_matmul_sm100_blockwise_fp8`](./grouped_matmul_sm100_blockwise_fp8/): * [​`grouped_matmul_tile_scheduler`](./grouped_matmul_tile_scheduler/): * [​`lora`](./lora/): * [​`matrix_band_part`](./matrix_band_part/): The module implements matrix band part functions. * [​`packing`](./packing/): * [​`qr_factorization`](./qr_factorization/): * [​`structuring`](./structuring/): * [​`transpose`](./transpose/): The module implements transpose functions. * [​`utils`](./utils/): * [​`utils_gpu`](./utils_gpu/):
--- ## lora
## Functions * [​`shrink_qkv_permute_3mn_sm100`](./shrink_qkv_permute_3mn_sm100): LoRA shrink GMM with planar Q/K/V output on SM100.
--- ## shrink_qkv_permute_3mn_sm100
`shrink_qkv_permute_3mn_sm100[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList](c_lora: NDBuffer[c_type, 3, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, num_active_experts: Int, ctx: DeviceContext)` LoRA shrink GMM with planar Q/K/V output on SM100. Performs the LoRA 'shrink' grouped matmul for routed tokens: computes `[M, K] @ [G, 3N, K]^T` per active expert, then **permutes** the flat `[M, 3N]` result into a planar layout `[3, M, N]` (Q, K, V) using an elementwise epilogue, while reusing the same storage. **Constraints:** * c\_lora must be rank 3 with static first dimension B == 3. * a must be rank 2 with trailing dimension K that matches b\[..., K]. * b must be rank 3 with shape (G, 3N, K). * The temporary 2D view of c\_lora is (M, 3N) in row-major order and **aliases the same storage** as c\_lora. * a\_offsets is non-decreasing with a\_offsets\[0] == 0 and a\_offsets\[num\_active\_experts] == M. * expert\_ids\[i] ∈ \[0, G) for valid experts; kernel may treat -1 as inactive. * The epilogue assumes `N % vector_width == 0` for aligned vector stores. **Args:** * ​c\_lora ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Output tensor with planar Q/K/V layout, shape (3, M, N). Backed by row-major storage, used both as a 3D view and as a temporary 2D view (M, 3N) during compute. * ​a ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Routed activation matrix, shape (M, K). * ​b ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Shrink weights per expert, shape (G, 3N, K). * ​a\_offsets ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Inclusive prefix sums of tokens per (active) expert, length (num\_experts + 1). Defines per-expert \[start, end) in A/C. 
* ​expert\_ids ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Expert indices for the active groups, length ≥ num\_active\_experts. * ​max\_num\_tokens\_per\_expert ([`Int`](/mojo/stdlib/builtin/int/Int)): Upper bound on tokens for any active expert. * ​num\_active\_experts ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of experts participating in this call. * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): DeviceContext used for enqueues and synchronization.
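The permute step described above presumably splits the 3N columns of the flat result into three N-wide blocks, so element (m, j) of the `[M, 3N]` view lands at plane `j // N`, row `m`, column `j % N` of the `[3, M, N]` view. A pure-Python sketch of that index mapping (out-of-place here for clarity, whereas the kernel's epilogue writes into aliased storage):

```python
# Sketch of the planar Q/K/V permute: flat [M, 3N] -> planar [3, M, N].
# Plane 0 is Q, plane 1 is K, plane 2 is V. Values encode (row, column)
# of the flat view so the mapping is easy to inspect.

M, N = 2, 4
flat = [[m * 100 + j for j in range(3 * N)] for m in range(M)]  # [M, 3N]

planar = [[[0] * N for _ in range(M)] for _ in range(3)]        # [3, M, N]
for m in range(M):
    for j in range(3 * N):
        planar[j // N][m][j % N] = flat[m][j]   # plane = j // N, col = j % N

q, k, v = planar
```

Doing this in place on the same storage, as the kernel does, requires care because source and destination indices overlap; the elementwise epilogue sidesteps that by computing each output element before it is overwritten.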
--- ## apple_batched_matmul
`apple_batched_matmul[*, transpose_b: Bool = False, elementwise_epilogue_fn: OptionalReg[fn[c_type: DType, width: Int, rank: Int, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None] = None](c: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive])`
--- ## apple_gemv
`apple_gemv[*, b_packed: Bool, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape])`
--- ## apple_matmul
`apple_matmul[*, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](cblas_gemm_fn: cblas_gemm_type, c: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive])` `apple_matmul[*, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, rank, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive])`
--- ## get_cblas_f32_function
`get_cblas_f32_function() -> cblas_gemm_type` **Returns:** [`cblas_gemm_type`](/mojo/kernels/linalg/matmul/cpu/apple_accelerate/#cblas_gemm_type)
--- ## apple_accelerate
## `comptime` values ### `APPLE_ACCELERATE` `comptime APPLE_ACCELERATE = _Global["APPLE_ACCELERATE", _init_dylib, _on_error_msg]` ### `cblas_gemm_type` `comptime cblas_gemm_type = fn(_CBLASOrder, _CBLASTranspose, _CBLASTranspose, Int32, Int32, Int32, Float32, LegacyUnsafePointer[Float32], Int32, LegacyUnsafePointer[Float32], Int32, Float32, LegacyUnsafePointer[Float32], Int32) -> None` ### `LIB_ACC_PATH` `comptime LIB_ACC_PATH = "/System/Library/Frameworks/Accelerate.framework/Accelerate"` ## Functions * [​`apple_batched_matmul`](./apple_batched_matmul): * [​`apple_gemv`](./apple_gemv): * [​`apple_matmul`](./apple_matmul): * [​`get_cblas_f32_function`](./get_cblas_f32_function): * [​`use_apple_accelerate_lib`](./use_apple_accelerate_lib):
--- ## use_apple_accelerate_lib
`use_apple_accelerate_lib[c_type: DType, a_type: DType, b_type: DType]() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## Inner_matmul_default
`struct Inner_matmul_default` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function on the inner loop. Run the inner kernel on the whole (kernel\_rows, TileN, TileK) tile.
--- ## default
## Structs * [​`Inner_matmul_default`](./Inner_matmul_default):
--- ## Inner_matmul_i8mm
`struct Inner_matmul_i8mm` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function for the inner loop. Runs the inner kernel on the whole (kernel\_rows, TileN, TileK) tile.
--- ## LoadStore_i8mm
`struct LoadStore_i8mm[dtype: DType, simd_size: Int, single_row: Bool, tile_rows: Int, tile_columns: Int]` ## Fields * ​output\_tile (`_Accumulator[dtype, tile_rows, LoadStore_i8mm[dtype, simd_size, single_row, tile_rows, tile_columns].num_simd_cols, simd_size]`): * ​skip\_boundary\_check (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `num_simd_cols` `comptime num_simd_cols = (tile_columns // simd_size)` ## Methods ### `__init__` `__init__(out self, skip_boundary_check: Bool)`
--- ## i8mm
## Structs * [​`Inner_matmul_i8mm`](./Inner_matmul_i8mm): * [​`LoadStore_i8mm`](./LoadStore_i8mm):
--- ## InnerMatmulKernel
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. 
**Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ## Provided methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self: _Self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
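The trait's required methods above can be satisfied by a small conforming struct. A hypothetical sketch (the struct name and bodies are illustrative; `__inner_matmul__` and `copy` come from the trait's provided defaults unless overridden):

```mojo
# Hypothetical conformer: InnerMatmulKernel only *requires* the copy and
# move constructors; __inner_matmul__ and copy() have provided defaults.
struct Inner_matmul_example(InnerMatmulKernel):
    fn __copyinit__(out self, existing: Self):
        # Stateless kernel: nothing to copy.
        pass

    fn __moveinit__(out self, deinit existing: Self):
        # Stateless kernel: nothing to move.
        pass
```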
--- ## TiledMatmul
`struct TiledMatmul[a_mut: Bool, b_mut: Bool, //, config: KernelConfig, transpose_b: Bool, b_packed: Bool, elementwise_epilogue_enabled: Bool, kernel_id: InnerKernelID, a_type: DType, a_shape: DimList, a_origin: Origin[a_mut], b_type: DType, b_shape: DimList, b_origin: Origin[b_mut], c_type: DType, c_shape: DimList, c_origin: MutOrigin, algorithm: InnerMatmulKernel]` Tiled matmul implementation integrating packing, inner loop and tile partitions. TODO: add tag based implementation dispatch. TODO: add fusion hooks. ## Fields * ​alg (`algorithm`): * ​c (`NDBuffer[c_type, 2, c_origin, c_shape]`): * ​a (`NDBuffer[a_type, 2, a_origin, a_shape]`): * ​b (`NDBuffer[b_type, 2, b_origin, b_shape]`): * ​tile\_n\_k (`IndexList[2]`): * ​global\_tile\_offset (`GemmShape`): * ​global\_tile\_shape (`GemmShape`): * ​b\_tile\_generator (`BTileGenerator[config, a_type, b_type, c_type, b_shape, transpose_b, b_packed, b_origin]`): * ​elementwise\_epilogue\_fn (`fn(GemmShape, GemmShape) escaping -> None`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` Compiler-generated flag. The generated expression is a field-by-field expansion whose only non-`True` leaves are `algorithm.__copyinit__is_trivial` and `fn(GemmShape, GemmShape) escaping -> None.__copyinit__is_trivial`; copying is trivial only when both the `algorithm` parameter and the elementwise-epilogue closure are trivially copyable. ### `__del__is_trivial` `comptime __del__is_trivial` Compiler-generated flag with the same structure: destruction is trivial only when both `algorithm.__del__is_trivial` and the epilogue closure's `__del__is_trivial` hold. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` Compiler-generated flag with the same structure: moving is trivial only when both `algorithm.__moveinit__is_trivial` and the epilogue closure's `__moveinit__is_trivial` hold.
True if True if True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial 
else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if 
algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial else True if algorithm.__moveinit__is_trivial else algorithm.__moveinit__is_trivial`
--- ## elementwise_epilogue_c_tile
`elementwise_epilogue_c_tile[simd_width: Int, dtype: DType, origin: MutOrigin, c_shape: DimList, func: fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None](offset: GemmShape, tile_len: GemmShape, c: NDBuffer[dtype, 2, origin, c_shape])`
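The epilogue callback receives the global 2-D index of each SIMD-width chunk of the C sub-tile together with the chunk itself. A minimal pure-Python sketch of that traversal (the helper and its list-based buffers are hypothetical stand-ins, not the Mojo implementation):

```python
def elementwise_epilogue_c_tile(offset, tile_len, c, func, simd_width=2):
    # c is a list of row lists; offset/tile_len play the role of the
    # GemmShape (m, n) origin and extent. func receives the global
    # (row, col) index of each chunk, mirroring the capturing lambda
    # signature in the Mojo declaration above.
    m0, n0 = offset
    tm, tn = tile_len
    for i in range(tm):
        for j in range(0, tn, simd_width):
            w = min(simd_width, tn - j)
            func((m0 + i, n0 + j), c[m0 + i][n0 + j : n0 + j + w])

# Example: fuse a ReLU into the store of rows 1..2 of a 3x4 C buffer.
c = [[-6.0, -5.0, -4.0, -3.0], [-2.0, -1.0, 0.0, 1.0], [2.0, 3.0, 4.0, 5.0]]
out = [[0.0] * 4 for _ in range(3)]

def store_relu(idx, vec):
    r, col = idx
    out[r][col:col + len(vec)] = [max(v, 0.0) for v in vec]

elementwise_epilogue_c_tile((1, 0), (2, 4), c, store_relu)
```

Row 0 lies outside the tile, so it is never touched by the epilogue.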
--- ## impl
## Structs * [​`TiledMatmul`](./TiledMatmul): Tiled matmul implementation integrating packing, inner loop and tile partitions. ## Traits * [​`InnerMatmulKernel`](./InnerMatmulKernel): ## Functions * [​`elementwise_epilogue_c_tile`](./elementwise_epilogue_c_tile): * [​`matmul`](./matmul): * [​`tiled_matmul_run`](./tiled_matmul_run): Interface function to run tiled matmul on a given sub-tile.
--- ## matmul
`matmul[*, transpose_b: Bool = False, b_packed: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False](c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], kernel_type_m: Int, num_threads: Int = -1)`
--- ## tiled_matmul_run
`tiled_matmul_run[config: KernelConfig, transpose_b: Bool, b_packed: Bool, simd_size: Int, elementwise_epilogue_enabled: Bool, kernel_id: InnerKernelID, algorithm: InnerMatmulKernel](alg: algorithm, c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], elementwise_epilogue_fn: fn(GemmShape, GemmShape) escaping -> None, global_tile_shape: GemmShape, global_tile_offset: GemmShape)` Interface function to run tiled matmul on a given sub-tile. **Args:** * ​alg (`algorithm`): InnerMatmulKernel algorithm for the microkernel. * ​c ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Pre-allocated buffer space for the result. * ​a ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Operand A of the matmul. * ​b ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): Operand B of the matmul. * ​elementwise\_epilogue\_fn (`fn(GemmShape, GemmShape) escaping -> None`): The elementwise epilogue function. * ​global\_tile\_shape ([`GemmShape`](/mojo/kernels/linalg/utils/GemmShape)): Tile shape this call will process. * ​global\_tile\_offset ([`GemmShape`](/mojo/kernels/linalg/utils/GemmShape)): Tile offset in the original buffer.
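The shape/offset pair partitions the global (M, N, K) iteration space: each call processes one sub-tile and then hands its shape and offset to the epilogue callback. A pure-Python sketch of that contract (the naive `matmul_block` microkernel and list-based buffers are hypothetical stand-ins):

```python
def matmul_block(a, b):
    # Naive microkernel standing in for an InnerMatmulKernel.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]

def tiled_matmul_run(alg, c, a, b, epilogue, tile_shape, tile_offset):
    # Run alg on one (M, N, K) sub-tile of the global problem, accumulate
    # into c, then invoke the epilogue with the tile's shape and offset,
    # mirroring the fn(GemmShape, GemmShape) callback above.
    m, n, k = tile_shape
    m0, n0, k0 = tile_offset
    sub_a = [row[k0:k0 + k] for row in a[m0:m0 + m]]
    sub_b = [row[n0:n0 + n] for row in b[k0:k0 + k]]
    part = alg(sub_a, sub_b)
    for i in range(m):
        for j in range(n):
            c[m0 + i][n0 + j] += part[i][j]
    epilogue(tile_shape, tile_offset)

# Split the K dimension of a 2x4 @ 4x3 product into two K-tiles of depth 2.
a = [[1, 2, 3, 4], [5, 6, 7, 8]]
b = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 1]]
c = [[0] * 3 for _ in range(2)]
calls = []
tiled_matmul_run(matmul_block, c, a, b,
                 lambda s, o: calls.append((s, o)), (2, 3, 2), (0, 0, 0))
tiled_matmul_run(matmul_block, c, a, b,
                 lambda s, o: calls.append((s, o)), (2, 3, 2), (0, 0, 2))
```

Both K-tiles accumulate into the same C block, and the epilogue fires once per call.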
--- ## cpu (Cpu)
Provides the CPU backend implementations for matmuls. ## Modules * [​`apple_accelerate`](./apple_accelerate/): * [​`default`](./default/): * [​`i8mm`](./i8mm/): * [​`impl`](./impl/): * [​`neon`](./neon/): * [​`vnni`](./vnni/):
--- ## Inner_matmul_neon
`struct Inner_matmul_neon` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function on the inner loop. Run the inner kernel on the whole (kernel\_rows, TileN, TileK) tile.
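`__inner_matmul__` accumulates a (kernel\_rows × TileN) register block while walking TileK. A pure-Python sketch of that register-blocked loop structure (a conceptual stand-in with hypothetical names, not the NEON kernel):

```python
def inner_matmul(c, a, b_packed, kernel_rows, tile_n, tile_k, simd_size=2):
    # For each k step, broadcast one A element per row and accumulate it
    # against a simd_size-wide slice of packed B - the FMA structure the
    # real kernel maps onto NEON vector registers.
    acc = [[0.0] * tile_n for _ in range(kernel_rows)]
    for k in range(tile_k):
        for r in range(kernel_rows):
            for n0 in range(0, tile_n, simd_size):
                for j in range(n0, min(n0 + simd_size, tile_n)):
                    acc[r][j] += a[r][k] * b_packed[k][j]
    # Spill the register block back into the C tile.
    for r in range(kernel_rows):
        for j in range(tile_n):
            c[r][j] += acc[r][j]

a = [[1, 2, 3], [4, 5, 6]]
b = [[7, 8], [9, 10], [11, 12]]
c = [[0.0] * 2 for _ in range(2)]
inner_matmul(c, a, b, kernel_rows=2, tile_n=2, tile_k=3)
```

The result matches a plain 2x3 by 3x2 matrix product.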
--- ## neon
## Structs * [​`Inner_matmul_neon`](./Inner_matmul_neon):
--- ## Inner_matmul_vnni
`struct Inner_matmul_vnni[saturated_vnni: Bool]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`InnerMatmulKernel`](/mojo/kernels/linalg/matmul/cpu/impl/InnerMatmulKernel), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__inner_matmul__` `__inner_matmul__[kernel_rows: Int, kernel_cols: Int, simd_size: Int](self, c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], global_offset: GemmShape, global_bound: GemmShape, tile_n_k: IndexList[2], skip_boundary_check: Bool)` Utility function on the inner loop. Run the inner kernel on the whole (kernel\_rows, TileN, TileK) tile.
--- ## vnni
## Structs * [​`Inner_matmul_vnni`](./Inner_matmul_vnni):
--- ## amd
Provides the AMD GPU backend implementations for matmuls. ## Modules * [​`matmul`](./matmul/): * [​`pingpong_kernel`](./pingpong_kernel/): * [​`ring_buffer`](./ring_buffer/): Ring Buffer implementation for producer-consumer synchronization in GPU kernels. * [​`ring_buffer_traits`](./ring_buffer_traits/): Trait definitions and utilities for ring buffer synchronization strategies. * [​`structured`](./structured/): * [​`warp_spec_matmul`](./warp_spec_matmul/): AMD Warp-Specialized Matrix Multiplication
--- ## MMATileBuffers
`struct MMATileBuffers[mut: Bool, dtype: DType, layout: Layout, origin: Origin[mut], address_space: AddressSpace, element_layout: Layout, layout_int_type: DType, linear_idx_type: DType, masked: Bool, alignment: Int, //, _dtype: DType, /, smem_layout: Layout, reg_tile_layout: Layout, tensor_type: AnyStruct[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]], thread_layout: Layout, warp_rows: Int, warp_cols: Int, swizzle: Swizzle]` Manages memory for a single matrix (A or B) in GEMM computation. This struct encapsulates all memory handling for a matrix, including: * Shared memory allocation and tiling * Register buffer allocation * Data movement between memory levels (DRAM→local→shared) ## Fields * ​smem\_tile (`MMATileBuffers[_dtype, smem_layout, reg_tile_layout, tensor_type, thread_layout, warp_rows, warp_cols, swizzle].SMemTileType`): * ​smem\_warp\_tile (`LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(smem_layout, AddressSpace.SHARED), _get_index_type(smem_layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(smem_layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(smem_layout, AddressSpace.SHARED), masked=_tile_is_masked[smem_layout, warp_rows, warp_cols]()]`): * ​load\_reg\_tile (`MMATileBuffers[_dtype, smem_layout, reg_tile_layout, tensor_type, thread_layout, warp_rows, warp_cols, swizzle].MMARegTileType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `MMARegTileType` `comptime 
MMARegTileType = LayoutTensor[_dtype, reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `SMemTileType` `comptime SMemTileType = LayoutTensor[_dtype, smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_idx: Int, warp_k_idx: Int, block_idx: Int)` Initialize memory regions for a matrix based on warp coordinates. **Args:** * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to load from global memory. * ​warp\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The warp index within the computation grid (used for MMA operations). * ​warp\_k\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The warp index along the K dimension (used for MMA operations). * ​block\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The block index within the computation grid (used for warp tiling). ### `copy_to_smem` `copy_to_smem(self)` Copy data from thread-local memory to shared memory. Uses structured thread cooperation to efficiently transfer data.
--- ## MmaOpAMD
`struct MmaOpAMD[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool, k_group_size: Int, num_k_tiles: Int, num_m_mmas: Int, num_n_mmas: Int, out_frag_size: Int, swizzle: Swizzle]` ## Fields * ​out\_reg\_tile (`MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].OutRegTileType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `alignment` `comptime alignment = align_of[SIMD[in_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width]]()` ### `out_reg_layout` `comptime out_reg_layout = Layout.row_major((num_m_mmas * num_n_mmas), out_frag_size)` ### `OutRegTileType` `comptime OutRegTileType = LayoutTensor[out_type, MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].out_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `reg_tile_layout` `comptime reg_tile_layout[num_mmas: Int] = Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width)` #### Parameters * ​num\_mmas ([`Int`](/mojo/stdlib/builtin/int/Int)): ### `RegTileType` `comptime RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major((num_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` #### Parameters * ​num\_mmas ([`Int`](/mojo/stdlib/builtin/int/Int)): ### `simd_width` `comptime simd_width = simd_width_of[in_type]()` ### `tensor_core_mma` `comptime tensor_core_mma = 
TiledTensorCore[out_type, in_type, shape, k_group_size, transpose_b]()` ## Methods ### `__init__` `__init__(out self)` ### `a_reg_tile` `a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), False, align_of[in_type](), num_m_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_m_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_m_mmas, simd_width_of[in_type]()]()]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `b_reg_tile` `b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major((num_n_mmas 
* num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), False, align_of[in_type](), num_n_mmas, simd_width_of[in_type]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((num_n_mmas * num_k_tiles), MmaOpAMD[out_type, in_type, shape, transpose_b, k_group_size, num_k_tiles, num_m_mmas, num_n_mmas, out_frag_size, swizzle].simd_width), num_n_mmas, simd_width_of[in_type]()]()]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `mma` `mma[k_tile_idx: Int](self)` ### `load_tile_fragment` `load_tile_fragment[k_tile_idx: Int](self, a_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(layout, AddressSpace.SHARED), _get_index_type(layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, 
address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()], b_smem_tiles: LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(layout, AddressSpace.SHARED), _get_index_type(layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()])` ### `reset_accumulator` `reset_accumulator(self)`
--- ## gemm_kernel_amd
`gemm_kernel_amd[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, c_layout_int_type: DType, a_layout_int_type: DType, b_layout_int_type: DType, c_linear_idx_type: DType, a_linear_idx_type: DType, b_linear_idx_type: DType, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin, layout_int_type=c_layout_int_type, linear_idx_type=c_linear_idx_type], a: LayoutTensor[a_type, a_layout, MutAnyOrigin, layout_int_type=a_layout_int_type, linear_idx_type=a_linear_idx_type], b: LayoutTensor[b_type, b_layout, MutAnyOrigin, layout_int_type=b_layout_int_type, linear_idx_type=b_linear_idx_type])` AMD-optimized GEMM kernel for matrix multiplication C = A \* B. This kernel implements an efficient matrix multiplication algorithm optimized for AMD GPUs, with hierarchical tiling and structured memory access patterns. **Parameters:** * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the output matrix C. * ​c\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for matrix C. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the input matrix A. * ​a\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for matrix A. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the input matrix B. * ​b\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for matrix B. * ​transpose\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether matrix B should be transposed. * ​c\_layout\_int\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the integer part of matrix C. * ​a\_layout\_int\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the integer part of matrix A. 
* ​b\_layout\_int\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the integer part of matrix B. * ​c\_linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the linear index of matrix C. * ​a\_linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the linear index of matrix A. * ​b\_linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the linear index of matrix B. * ​config ([`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)): GEMM configuration parameters (tile sizes, etc.). * ​elementwise\_lambda\_fn ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional function to apply to output elements. **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C (result). * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix A. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix B (must be transposed).
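The hierarchical tiling the kernel maps onto thread blocks can be sketched in pure Python: loop over BM×BN output blocks and accumulate BK-deep slices of the K dimension (block sizes and helper name are illustrative, not the kernel's configuration):

```python
def gemm_blocked(a, b, bm=2, bn=2, bk=2, transpose_b=False):
    # Hierarchically tiled C = A x B (or A x B^T when transpose_b is set),
    # with bounds clamping on ragged edge tiles.
    if transpose_b:
        b = [list(col) for col in zip(*b)]
    m, k, n = len(a), len(a[0]), len(b[0])
    c = [[0.0] * n for _ in range(m)]
    for i0 in range(0, m, bm):            # BM x BN output block
        for j0 in range(0, n, bn):
            for k0 in range(0, k, bk):    # BK-deep accumulation slice
                for i in range(i0, min(i0 + bm, m)):
                    for j in range(j0, min(j0 + bn, n)):
                        c[i][j] += sum(a[i][p] * b[p][j]
                                       for p in range(k0, min(k0 + bk, k)))
    return c

a = [[1, 2], [3, 4], [5, 6]]
b = [[1, 2, 3], [4, 5, 6]]
```

Tiling only reorders the reduction, so the result equals the untiled product, and passing B pre-transposed with `transpose_b=True` gives the same answer.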
--- ## matmul (Matmul)
## `comptime` values ### `SMemWarpTileType` `comptime SMemWarpTileType[_dtype: DType, layout: Layout, warp_rows: Int, warp_cols: Int] = LayoutTensor[_dtype, LayoutTensor._compute_tile_layout[True, _dtype, layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(layout, AddressSpace.SHARED), _get_index_type(layout, AddressSpace.SHARED), False, align_of[_dtype](), warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(layout, AddressSpace.SHARED), masked=_tile_is_masked[layout, warp_rows, warp_cols]()]` Type alias for warp-level shared memory tiles with specified dimensions. #### Parameters * ​\_dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): * ​warp\_rows ([`Int`](/stdlib/builtin/int/Int)): * ​warp\_cols ([`Int`](/stdlib/builtin/int/Int)): ## Structs * [​`MmaOpAMD`](./MmaOpAMD): * [​`MMATileBuffers`](./MMATileBuffers): Manages memory for a single matrix (A or B) in GEMM computation. ## Functions * [​`gemm_kernel_amd`](./gemm_kernel_amd): AMD-optimized GEMM kernel for matrix multiplication C = A \* B. * [​`write_output_fragments`](./write_output_fragments): Write output fragments from registers to global memory with optional elementwise operations.
--- ## write_output_fragments
`write_output_fragments[c_type: DType, c_frag_size: Int, MMA_M: Int, MMA_N: Int, output_thread_layout: Layout, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c_reg_fragment: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_gmem_fragment: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_tile_m: Int, warp_tile_n: Int, M: Int, N: Int)` Write output fragments from registers to global memory with optional elementwise operations. **Parameters:** * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type for the output matrix C. * ​c\_frag\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of each output fragment. * ​MMA\_M ([`Int`](/mojo/stdlib/builtin/int/Int)): Matrix multiply instruction M dimension. * ​MMA\_N ([`Int`](/mojo/stdlib/builtin/int/Int)): Matrix multiply instruction N dimension. * ​output\_thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread layout for output operations. * ​elementwise\_lambda\_fn ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional elementwise operation to apply. **Args:** * ​c\_reg\_fragment ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Register fragments containing computation results. * ​c\_gmem\_fragment ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Global memory fragment for output. * ​warp\_tile\_m ([`Int`](/mojo/stdlib/builtin/int/Int)): M coordinate of the warp tile. * ​warp\_tile\_n ([`Int`](/mojo/stdlib/builtin/int/Int)): N coordinate of the warp tile. 
* ​M ([`Int`](/mojo/stdlib/builtin/int/Int)): Total M dimension of the output matrix. * ​N ([`Int`](/mojo/stdlib/builtin/int/Int)): Total N dimension of the output matrix.
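The fragment writeback amounts to scattering an MMA-shaped register tile into C at the warp tile's origin, masking against the true M×N bounds and optionally routing each value through the elementwise lambda. A pure-Python sketch under those assumptions (list-based buffers; not the GPU implementation):

```python
def write_output_fragments(frag, c, warp_tile_m, warp_tile_n, M, N,
                           elementwise_lambda_fn=None):
    # Scatter the fragment into C at (warp_tile_m, warp_tile_n), skipping
    # elements that fall outside the true M x N bounds, and applying the
    # optional elementwise lambda to each value before the store.
    for i, row in enumerate(frag):
        for j, v in enumerate(row):
            r, col = warp_tile_m + i, warp_tile_n + j
            if r < M and col < N:
                c[r][col] = (elementwise_lambda_fn((r, col), v)
                             if elementwise_lambda_fn else v)

# Bounds masking: a 2x2 fragment at (1, 1) of a 2x2 output only lands once.
c = [[0, 0], [0, 0]]
write_output_fragments([[1, 2], [3, 4]], c, 1, 1, M=2, N=2)

# Elementwise fusion: double every value on the way out.
c2 = [[0] * 3 for _ in range(3)]
write_output_fragments([[1, 2], [3, 4]], c2, 0, 0, M=3, N=3,
                       elementwise_lambda_fn=lambda idx, v: 2 * v)
```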
--- ## AMDPingPongMatmul
`struct AMDPingPongMatmul[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, config: KernelConfig, /, enable_l2_cache_optimization: Bool, enable_swizzle: Bool, use_transpose_load: Bool]` High-level ping-pong matmul implementation for AMD GPUs. This implements the 8-warp ping-pong pattern where warps alternate between loading data and computing, achieving overlapped execution. Memory Layout Strategy for Bank Conflict Avoidance: 1. Shared Memory Organization (AMD MI355 has 64 banks, 4 bytes each): * Uses double-buffered shared memory (ping-pong buffers) * Each buffer holds BM×BK elements for A, BN×BK for B 2. Bank Conflict Avoidance Pattern: * Bank index = (address / 4) % 64 * Swizzled access pattern distributes consecutive thread accesses across banks * Column swizzle: (lane\_id % 4) \* load\_width spreads within 32 bytes * Row stride: (lane\_id // 4) \* K ensures different rows map to different banks * Warp-level offsets further distribute accesses 3. Load Pattern (Global → Shared Memory): * Uses AMD's load\_to\_lds instruction for direct DRAM→LDS transfer * Bypasses L1/L2 caches for lower latency * Coalesced global memory access (consecutive threads → consecutive addresses) * Bank-conflict-free shared memory writes via swizzled offsets 4. 
MMA Access Pattern (Shared Memory → Registers): * Optimized for AMD's matrix cores (4 per CU on MI355) * 16×4 thread layout within each warp for MMA fragments * Ensures all 4 matrix cores stay busy throughout execution ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `accum_dtype` `comptime accum_dtype = get_accum_type[c_type]()` ### `accum_width` `comptime accum_width = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].MMA_M * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].MMA_N) // WARP_SIZE)` ### `BK` `comptime BK = config.block_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_shape.__getitem__[3, DType.int64, Int](1)` ### `LGKM_PER_LOAD_A` `comptime LGKM_PER_LOAD_A = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].quadrant_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_k_mmas)` ### `LGKM_PER_LOAD_AB` `comptime LGKM_PER_LOAD_AB = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].LGKM_PER_LOAD_A + AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].LGKM_PER_LOAD_B)` ### `LGKM_PER_LOAD_B` `comptime LGKM_PER_LOAD_B = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, 
c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].quadrant_n_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_k_mmas)` ### `load_width` `comptime load_width = simd_width_of[a_type]()` ### `loading_threads_4warp` `comptime loading_threads_4warp = (4 * WARP_SIZE)` ### `loading_threads_8warp` `comptime loading_threads_8warp = (8 * WARP_SIZE)` ### `loads_per_row` `comptime loads_per_row = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BK // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].load_width)` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_accums` `comptime num_accums = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_m_mmas * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_n_mmas)` ### `num_k_mmas` `comptime num_k_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].WK // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].WM // 
AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].MMA_M)` ### `num_n_mmas` `comptime num_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].WN // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].MMA_N)` ### `num_warps_m` `comptime num_warps_m = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BM // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].WM)` ### `num_warps_n` `comptime num_warps_n = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BN // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].WN)` ### `ping_pong_stages` `comptime ping_pong_stages = 2` ### `quadrant_m_mmas` `comptime quadrant_m_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_m_mmas // 2)` ### `quadrant_n_mmas` `comptime quadrant_n_mmas = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_n_mmas // 2)` ### `rows_per_iter_4warp` `comptime rows_per_iter_4warp = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].loading_threads_4warp // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, 
enable_l2_cache_optimization, enable_swizzle, use_transpose_load].loads_per_row)` ### `rows_per_iter_8warp` `comptime rows_per_iter_8warp = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].loading_threads_8warp // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].loads_per_row)` ### `total_smem_a` `comptime total_smem_a = ((2 * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BM) * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BK)` ### `total_smem_b` `comptime total_smem_b = ((2 * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BN) * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BK)` ### `total_warps` `comptime total_warps = (AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_warps_m * AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].num_warps_n)` ### `VMCNT_PER_LOAD_A` `comptime VMCNT_PER_LOAD_A = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BM // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].rows_per_iter_8warp)` ### `VMCNT_PER_LOAD_A_4WARP` `comptime VMCNT_PER_LOAD_A_4WARP = 
((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BM // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].rows_per_iter_4warp)` ### `VMCNT_PER_LOAD_B` `comptime VMCNT_PER_LOAD_B = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BN // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].rows_per_iter_8warp)` ### `VMCNT_PER_LOAD_B_4WARP` `comptime VMCNT_PER_LOAD_B_4WARP = ((AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].BN // 2) // AMDPingPongMatmul[a_type, b_type, c_type, a_layout, b_layout, c_layout, config, enable_l2_cache_optimization, enable_swizzle, use_transpose_load].rows_per_iter_4warp)` ### `WK` `comptime WK = config.warp_shape.__getitem__[3, DType.int64, Int](2)` ### `WM` `comptime WM = config.warp_shape.__getitem__[3, DType.int64, Int](0)` ### `WN` `comptime WN = config.warp_shape.__getitem__[3, DType.int64, Int](1)` ## Methods ### `validate_config` `static validate_config()` Validate the kernel configuration. ### `matmul_demo_ping_pong` `static matmul_demo_ping_pong(a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])`
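The bank arithmetic in the struct docstring (bank index = (address / 4) % 64) can be demonstrated with a simplified model. This sketch uses a single-bit-per-row XOR swizzle for illustration, not the kernel's exact `Swizzle` parameters; it shows why a power-of-two row stride sends a whole column read into one bank, and how an XOR of the row into the column index spreads it across all 64 banks.

```python
BANKS = 64          # AMD MI355 LDS: 64 banks, 4 bytes each (per the docs)
ROW_ELEMS = 64      # illustrative tile row length, in elements
ELEM_BYTES = 4

def bank(addr_bytes):
    # Bank index = (address / 4) % 64, as stated in the struct docstring.
    return (addr_bytes // 4) % BANKS

def addr(row, col, swizzle=False):
    if swizzle:
        col = col ^ (row % ROW_ELEMS)  # simplified XOR swizzle of the column
    return (row * ROW_ELEMS + col) * ELEM_BYTES

# 64 lanes each read column 0 of a different row.
naive = {bank(addr(r, 0)) for r in range(64)}
swz = {bank(addr(r, 0, swizzle=True)) for r in range(64)}
# naive hits a single bank (64-way conflict); the swizzled pattern
# touches all 64 banks, one lane per bank.
```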
--- ## KernelConfig
`struct KernelConfig` ## Fields * ​block\_shape (`IndexList[3]`): Block tile shape (BM, BN, BK). * ​warp\_shape (`IndexList[3]`): Warp tile shape (WM, WN, WK). * ​mma\_shape (`IndexList[3]`): MMA instruction shape (MMA\_M, MMA\_N, MMA\_K). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, *, block_shape: IndexList[3], warp_shape: IndexList[3], mma_shape: IndexList[3])` ### `num_threads` `num_threads(self) -> Int` Number of threads used by this kernel configuration. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `write_to` `write_to(self, mut writer: T)` ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
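A config's thread count follows from how many warp tiles cover the block tile. This is a hedged Python model of that derivation: the `num_threads` formula below is an assumption inferred from the `num_warps_m`/`num_warps_n` comptime members of `AMDPingPongMatmul`, and `WARP_SIZE = 64` reflects the AMD wavefront size mentioned in these docs.

```python
WARP_SIZE = 64  # AMD wavefront size, per the MMA thread-layout notes above

def num_threads(block_shape, warp_shape):
    """Hypothetical model of KernelConfig.num_threads(): one thread per
    lane of every warp tile covering the block tile."""
    bm, bn, _bk = block_shape
    wm, wn, _wk = warp_shape
    num_warps_m = bm // wm   # warp grid rows over the block tile
    num_warps_n = bn // wn   # warp grid cols over the block tile
    return num_warps_m * num_warps_n * WARP_SIZE

# Illustrative shapes: a 256x256x64 block tiled by 64x128 warp tiles
# gives a 4x2 warp grid, i.e. the 8-warp ping-pong configuration.
threads = num_threads((256, 256, 64), (64, 128, 64))
```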
--- ## MmaOp
`struct MmaOp[in_type: DType, accum_type: DType, WM: Int, WN: Int, BK: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, alignment: Int, enable_swizzle: Bool, swizzle_elem_base: Int, swizzle_shift: Int]` Encapsulates MMA register tiles and operations for matrix multiplication. This struct manages register tiles and MMA operations for a single warp. It processes warp-sized tiles (WM × BK for A, WN × BK for B) without knowledge of the broader kernel architecture. MmaOp accepts generic SMemTileType and validates compatibility at compile-time via load\_lds\_fragment constraints. Note: Several values are derived from other parameters: * num\_m\_mmas = WM // MMA\_M * num\_n\_mmas = WN // MMA\_N * num\_k\_mmas = BK // MMA\_K * load\_width = simd\_width\_of[in\_type]() (SIMD width for input type) * accum\_width = (MMA\_M \* MMA\_N) // WARP\_SIZE (elements per thread) Quadrant Processing: The warp tile is divided into 4 quadrants for MMA scheduling: * quadrant\_m\_mmas = num\_m\_mmas // 2 (M-dimension quadrant size) * quadrant\_n\_mmas = num\_n\_mmas // 2 (N-dimension quadrant size) This enables efficient interleaving of loads and computes. Thread Layout for MMA: AMD's expected pattern: 64 threads → 4 rows × 16 cols (row-major) Lane offset computed on-the-fly via lane\_id() Swizzle Configuration (enable\_swizzle=True): MmaOp receives swizzle parameters from the kernel/TileBuffers, since they are determined by how data is loaded into LDS. MmaOp must read using the same swizzle pattern that was used for writing. 
* swizzle\_elem\_base: bit position for XOR (from loading subtile width) * swizzle\_shift: XOR source distance (from loading subtile rows) ## Fields * ​a\_reg\_tile (`MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].ARegTileType`): * ​b\_reg\_tile (`MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].BRegTileType`): * ​out\_quadrants (`StaticTuple[StaticTuple[LayoutTensor[accum_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].quadrant_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].accum_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment], 2], 2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `accum_width` `comptime accum_width = ((MMA_M * MMA_N) // WARP_SIZE)` ### `ARegTileType` `comptime ARegTileType = LayoutTensor[in_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].load_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` ### `BRegTileType` `comptime BRegTileType = LayoutTensor[in_type, 
Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_n_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].load_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` ### `elem_swizzle` `comptime elem_swizzle = OptionalReg[Swizzle](Swizzle(1, swizzle_elem_base, swizzle_shift)) if enable_swizzle else OptionalReg[Swizzle]()` ### `lgkm_per_load_a` `comptime lgkm_per_load_a = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].quadrant_m_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_k_mmas)` ### `lgkm_per_load_ab` `comptime lgkm_per_load_ab = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].lgkm_per_load_a + MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].lgkm_per_load_b)` ### `lgkm_per_load_b` `comptime lgkm_per_load_b = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_k_mmas)` ### `load_width` `comptime load_width = simd_width_of[in_type]()` ### `mma_access_layout` `comptime mma_access_layout = Layout(IntTuple(16, 4), IntTuple((4 * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].load_width), MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, 
enable_swizzle, swizzle_elem_base, swizzle_shift].load_width))` ### `num_k_mmas` `comptime num_k_mmas = (BK // MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = (WM // MMA_M)` ### `num_n_mmas` `comptime num_n_mmas = (WN // MMA_N)` ### `OutQuadrantType` `comptime OutQuadrantType = LayoutTensor[accum_type, Layout.row_major(MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].quadrant_m_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].quadrant_n_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].accum_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` ### `quadrant_m_mmas` `comptime quadrant_m_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_m_mmas // 2)` ### `quadrant_n_mmas` `comptime quadrant_n_mmas = (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_n_mmas // 2)` ### `RegTileType` `comptime RegTileType[num_mmas: Int] = LayoutTensor[in_type, Layout.row_major(num_mmas, (MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].num_k_mmas * MmaOp[in_type, accum_type, WM, WN, BK, MMA_M, MMA_N, MMA_K, alignment, enable_swizzle, swizzle_elem_base, swizzle_shift].load_width)), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=alignment]` #### Parameters * ​num\_mmas ([`Int`](/mojo/stdlib/builtin/int/Int)): ## Methods ### `__init__` `__init__(out self)` Initialize MMA operation with register tiles. ### `reset_accumulator` `reset_accumulator(self)` Reset all output quadrants to zero. 
### `load_a` `load_a[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load A\[which] from LDS → registers. Accepts SMemTileType with matching dtype - layout compatibility validated at compile-time via load\_lds\_fragment constraints. ### `load_b` `load_b[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load B\[which] from LDS → registers. Accepts SMemTileType with matching dtype - layout compatibility validated at compile-time via load\_lds\_fragment constraints. ### `load_b_with_transpose` `load_b_with_transpose[which: Int](self, smem_tile: LayoutTensor[in_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load B\[which] from LDS → registers using hardware transpose. Uses ds\_read\_tr16\_b64 instruction for efficient transposed LDS read. This function expects B tiles in (N, K) storage order and produces data in the format expected by AMD MFMA instructions. Supports swizzle: When enable\_swizzle is True, applies the byte swizzle pattern to LDS read offsets for bank-conflict-free access. Requires: MMA shape must be 16x16x32 or 32x32x16 (double-rate MFMA). **Args:** * ​smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): B tile in LDS with shape (mma\_tile\_n, BK) = (N, K) order. ### `mma` `mma[which_a: Int, which_b: Int](self)` Execute MMA operations for a quadrant of the output tile. Each quadrant is stored in a separate contiguous register tile. 
**Parameters:** * ​which\_a ([`Int`](/mojo/stdlib/builtin/int/Int)): A quadrant index (0 or 1). * ​which\_b ([`Int`](/mojo/stdlib/builtin/int/Int)): B quadrant index (0 or 1).
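The derived quantities listed in the MmaOp note (num\_m\_mmas, accum\_width, the quadrant sizes, and the lgkm counts) can be computed directly from the stated formulas. This Python sketch mirrors them for one illustrative configuration; `mma_op_params` is a hypothetical helper, and the example shapes are not prescribed by the API.

```python
WARP_SIZE = 64  # AMD wavefront

def mma_op_params(WM, WN, BK, MMA_M, MMA_N, MMA_K):
    """Derived MmaOp quantities, following the formulas in the docs."""
    p = {}
    p["num_m_mmas"] = WM // MMA_M
    p["num_n_mmas"] = WN // MMA_N
    p["num_k_mmas"] = BK // MMA_K
    p["accum_width"] = (MMA_M * MMA_N) // WARP_SIZE  # accum elems per lane
    # The warp tile is split into 4 quadrants for MMA scheduling.
    p["quadrant_m_mmas"] = p["num_m_mmas"] // 2
    p["quadrant_n_mmas"] = p["num_n_mmas"] // 2
    # LDS reads outstanding per quadrant load, for A and B sides.
    p["lgkm_per_load_a"] = p["quadrant_m_mmas"] * p["num_k_mmas"]
    p["lgkm_per_load_b"] = p["quadrant_n_mmas"] * p["num_k_mmas"]
    return p

# Example: 64x64 warp tile, BK=32, with the 16x16x32 double-rate MFMA shape.
p = mma_op_params(64, 64, 32, 16, 16, 32)
```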
--- ## TileBuffers
`struct TileBuffers[in_type: DType, a_layout: Layout, b_layout: Layout, //, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_threads: Int, alignment: Int, enable_swizzle: Bool, load_width: Int, loading_warps: Int = 8]` Double-buffered LDS tiles and TileLoaders for ping-pong matmul. a\_layout and b\_layout are infer-only parameters (note `//`), automatically extracted from the input tensors passed to `__init__`. K is derived as a `comptime` value from a\_layout.shape\[1]. ## Fields * ​a\_mma\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTilePair]`): * ​b\_mma\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTilePair]`): * ​a\_load\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTilePair]`): * ​b\_load\_tiles (`Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTilePair]`): * ​loader\_a (`TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].ATileLoader`): * ​loader\_b (`TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BTileLoader`): * ​warp\_id\_m (`Int`): * ​k\_offset (`Int`): * ​warp\_shift\_rows (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` 
members ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `AHalfTile` `comptime AHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` ### `AHalfTilePair` `comptime AHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AHalfTile]` ### `AMmaTile` `comptime AMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), False, alignment, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK), TileBuffers[BM, BN, BK, WM, WN, 
num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_m, BK](), alignment=alignment]` ### `AMmaTilePair` `comptime AMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].AMmaTile]` ### `ATileLoader` `comptime ATileLoader = TileLoaderLDS[in_type, a_layout, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width]` ### `BHalfTile` `comptime BHalfTile = LayoutTensor[in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` ### `BHalfTilePair` `comptime BHalfTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BHalfTile]` ### `BMmaTile` `comptime BMmaTile = LayoutTensor[in_type, LayoutTensor._compute_tile_layout[True, in_type, LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], MutAnyOrigin, 
AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK](), alignment, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), linear_idx_type=_get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), masked=_tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() if _tile_is_masked[Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), WN, BK]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, in_type, Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, BK), AddressSpace.SHARED), _get_index_type(Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BN, 
BK), AddressSpace.SHARED), False, alignment, WN, BK]()[0], TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].mma_tile_n, BK](), alignment=alignment]` ### `BMmaTilePair` `comptime BMmaTilePair = Tuple[TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].BMmaTile]` ### `BTileLoader` `comptime BTileLoader = TileLoaderLDS[in_type, b_layout, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width]` ### `byte_swizzle` `comptime byte_swizzle = OptionalReg[Swizzle](Swizzle(1, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_byte_base, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_shift)) if enable_swizzle else OptionalReg[Swizzle]()` ### `elem_size` `comptime elem_size = size_of[in_type]()` ### `elements_per_warp` `comptime elements_per_warp = (WARP_SIZE * load_width)` ### `half_BM` `comptime half_BM = (BM // 2)` ### `half_BN` `comptime half_BN = (BN // 2)` ### `half_tile_layout` `comptime half_tile_layout = Layout.row_major(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].half_BM, BK)` ### `HalfTile` `comptime HalfTile[rows: Int] = LayoutTensor[in_type, Layout.row_major(rows, BK), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` #### Parameters * ​rows ([`Int`](/mojo/stdlib/builtin/int/Int)): ### `K` `comptime K = a_layout.shape[1].value()` ### `loading_threads` `comptime loading_threads = (loading_warps * WARP_SIZE)` ### `loads_per_row` `comptime loads_per_row = (BK // load_width)` ### 
`mma_tile_m` `comptime mma_tile_m = (WM // 2)` ### `mma_tile_n` `comptime mma_tile_n = (WN // 2)` ### `rows_per_iter_4warp` `comptime rows_per_iter_4warp = ((4 * WARP_SIZE) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)` ### `rows_per_load_iteration` `comptime rows_per_load_iteration = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].loading_threads // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].loads_per_row)` ### `rows_per_warp` `comptime rows_per_warp = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].elements_per_warp // BK)` ### `smem_ptr` `comptime smem_ptr = LegacyUnsafePointer[Scalar[in_type], address_space=AddressSpace.SHARED]` ### `SMemTile` `comptime SMemTile[rows: Int, cols: Int] = LayoutTensor[in_type, Layout.row_major(rows, cols), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` #### Parameters * ​rows ([`Int`](/mojo/stdlib/builtin/int/Int)): * ​cols ([`Int`](/mojo/stdlib/builtin/int/Int)): ### `swizzle_byte_base` `comptime swizzle_byte_base = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_elem_base + log2_floor(TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].elem_size))` ### `swizzle_elem_base` `comptime swizzle_elem_base = log2_floor((TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].swizzle_subtile_cols // 2))` ### `swizzle_shift` `comptime swizzle_shift = log2_floor(16)` ### `swizzle_subtile_cols` `comptime swizzle_subtile_cols = (4 * load_width)` ### `swizzle_subtile_rows` `comptime swizzle_subtile_rows = 16` ### `TileLoader` `comptime TileLoader[src_layout: Layout] = TileLoaderLDS[in_type, src_layout, TileBuffers[BM, BN, BK, WM, WN, 
num_threads, alignment, enable_swizzle, load_width, loading_warps].half_tile_layout, loading_warps, TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].byte_swizzle, load_width]` #### Parameters * ​src\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ### `total_warps` `comptime total_warps = 8` ### `vmcnt_per_load_a` `comptime vmcnt_per_load_a = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)` ### `vmcnt_per_load_a_4warp` `comptime vmcnt_per_load_a_4warp = ((BM // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)` ### `vmcnt_per_load_ab` `comptime vmcnt_per_load_ab = (TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_a + TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].vmcnt_per_load_b)` ### `vmcnt_per_load_b` `comptime vmcnt_per_load_b = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_load_iteration)` ### `vmcnt_per_load_b_4warp` `comptime vmcnt_per_load_b_4warp = ((BN // 2) // TileBuffers[BM, BN, BK, WM, WN, num_threads, alignment, enable_swizzle, load_width, loading_warps].rows_per_iter_4warp)` ## Methods ### `__init__` `__init__(out self, a: LayoutTensor[in_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], block_row: Int, block_col: Int, warp_id: Int, warp_id_m: Int, warp_id_n: Int, lane_id: Int)` Initialize LDS tiles and 
loaders. Layouts inferred from a and b tensors. ### `advance_k` `advance_k(mut self)` Advance k\_offset by BK for the next K iteration. ### `load_a` `load_a[stage: Int, which: Int](self)` Load A\[stage]\[which] using 8 warps. ### `load_b` `load_b[stage: Int, which: Int](self)` Load B\[stage]\[which] using 8 warps. ### `load_a_as_group` `load_a_as_group[stage: Int, target_group: Int](self, caller_group: Int)` Load A\[stage]\[target\_group] using 4 warps. Only executes if caller\_group == target\_group. ### `load_b_as_group` `load_b_as_group[stage: Int, which: Int](self, caller_group: Int, loading_group: Int)` Load B\[stage]\[which] using 4 warps. Only executes if caller\_group == loading\_group.
--- ## TileLoaderLDS
`@register_passable(trivial)` `struct TileLoaderLDS[dtype: DType, src_layout: Layout, src_tile_layout: Layout, num_loading_warps: Int, swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle](), load_width: Int = simd_width_of[dtype]()]` Encapsulates load\_to\_lds with pre-computed thread positions and swizzle. ## Fields * ​buffer (`AMDBufferResource`): * ​thread\_row (`Int`): * ​thread\_col (`Int`): * ​warp\_id (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `elements_per_warp` `comptime elements_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].threads_per_warp * load_width)` ### `loading_threads` `comptime loading_threads = (num_loading_warps * TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].threads_per_warp)` ### `loads_per_row` `comptime loads_per_row = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].tile_cols // load_width)` ### `num_iterations` `comptime num_iterations = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].tile_rows // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].rows_per_iteration)` ### `rows_per_iteration` `comptime rows_per_iteration = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].loading_threads // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, 
load_width].loads_per_row)` ### `rows_per_warp` `comptime rows_per_warp = (TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].elements_per_warp // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].tile_cols)` ### `stride` `comptime stride = src_layout.shape[1].value()` ### `subtile_cols` `comptime subtile_cols = 32` ### `thread_rows` `comptime thread_rows = (WARP_SIZE // TileLoaderLDS[dtype, src_layout, src_tile_layout, num_loading_warps, swizzle, load_width].threads_per_row)` ### `threads_per_row` `comptime threads_per_row = (32 // load_width)` ### `threads_per_warp` `comptime threads_per_warp = WARP_SIZE` ### `tile_cols` `comptime tile_cols = src_tile_layout.shape[1].value()` ### `tile_rows` `comptime tile_rows = src_tile_layout.shape[0].value()` ## Methods ### `__init__` `__init__(src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_id: Int, lane_id: Int) -> Self` Pre-compute thread position with swizzle inversion for bank-conflict-free reads. ### `load_tile` `load_tile[dst_layout: Layout, //](self, dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_row: Int, src_col: Int)` Load a tile from source coordinates to LDS. Combines pre-computed thread position with source coordinates. Uses the buffer resource stored at init time. Only warps 0 to num\_loading\_warps-1 participate; others return immediately. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination LDS tile. * ​src\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row in source tensor. 
* ​src\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column in source tensor (typically k\_offset).
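The `loads_per_row`, `rows_per_iteration`, and `num_iterations` members above are derived by simple division. A Python sketch of the same arithmetic (illustrative only; the names mirror the comptime members, and `WARP_SIZE = 64` is assumed for AMD CDNA wavefronts):

```python
WARP_SIZE = 64  # AMD CDNA wavefront size assumed here

def tile_loader_geometry(tile_rows: int, tile_cols: int,
                         num_loading_warps: int, load_width: int):
    loading_threads = num_loading_warps * WARP_SIZE
    loads_per_row = tile_cols // load_width                # vector loads per tile row
    rows_per_iteration = loading_threads // loads_per_row  # rows filled per load step
    num_iterations = tile_rows // rows_per_iteration       # steps to fill the tile
    return loads_per_row, rows_per_iteration, num_iterations
```

For example, a 128×64 tile loaded by 4 warps with 8-wide loads needs 8 loads per row, fills 32 rows per iteration, and covers the tile in 4 iterations.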
--- ## chiplet_transform_chunked
`chiplet_transform_chunked[num_xcds: Int, chunk_size: Int](workgroup_id: Int, num_workgroups: Int) -> Int` Transform work group ID for better chiplet locality. AMD MI300X/MI355X have 8 XCDs (chiplets), each with its own L2 cache. This function reorganizes blocks from round-robin distribution to chunked allocation, improving cache locality. Original pattern: WG0→XCD0, WG1→XCD1, ..., WG8→XCD0 Transformed: WG0-63→XCD0, WG64-127→XCD1, etc. **Parameters:** * ​num\_xcds ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of XCDs (8 for MI300X/MI355X). * ​chunk\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of blocks per XCD chunk. **Args:** * ​workgroup\_id ([`Int`](/mojo/stdlib/builtin/int/Int)): Original block ID. * ​num\_workgroups ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of blocks. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Transformed block ID for better XCD locality.
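The remapping is pure index arithmetic. A Python sketch of one plausible implementation (an illustration, not the Mojo source; it assumes the hardware dispatches physical workgroup `i` to XCD `i % num_xcds` round-robin, and that `num_workgroups` is a multiple of `num_xcds * chunk_size`, so tail handling is omitted):

```python
def chiplet_transform_chunked(workgroup_id: int, num_workgroups: int,
                              num_xcds: int = 8, chunk_size: int = 64) -> int:
    # Hardware assigns physical workgroup i to XCD (i % num_xcds), so we pick a
    # physical ID whose residue mod num_xcds is the XCD the whole chunk targets.
    chunk = workgroup_id // chunk_size  # which chunk of blocks this WG belongs to
    xcd = chunk % num_xcds              # target XCD shared by the whole chunk
    slot = (chunk // num_xcds) * chunk_size + workgroup_id % chunk_size
    return slot * num_xcds + xcd        # residue mod num_xcds == xcd

# With chunk_size=64: logical WG0-63 all land on XCD0, WG64-127 on XCD1, ...
```

The transform is a permutation of workgroup IDs, so no work is duplicated or dropped; only the assignment of blocks to L2-sharing chiplets changes.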
--- ## pingpong_kernel
## Structs * [​`AMDPingPongMatmul`](./AMDPingPongMatmul): High-level ping-pong matmul implementation for AMD GPUs. * [​`KernelConfig`](./KernelConfig): * [​`MmaOp`](./MmaOp): Encapsulates MMA register tiles and operations for matrix multiplication. * [​`TileBuffers`](./TileBuffers): Double-buffered LDS tiles and TileLoaders for ping-pong matmul. * [​`TileLoaderLDS`](./TileLoaderLDS): Encapsulates load\_to\_lds with pre-computed thread positions and swizzle. ## Functions * [​`chiplet_transform_chunked`](./chiplet_transform_chunked): Transform work group ID for better chiplet locality. * [​`load_lds_fragment`](./load_lds_fragment): Load LDS → registers with MMA access pattern. * [​`ping_pong_matmul`](./ping_pong_matmul):
--- ## load_lds_fragment
`load_lds_fragment[dtype: DType, smem_layout: Layout, smem_element_layout: Layout, frag_layout: Layout, frag_element_layout: Layout, //, mma_access_layout: Layout, swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]()](smem_tile: LayoutTensor[dtype, smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=smem_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], reg_frag: LayoutTensor[dtype, frag_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=frag_element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load LDS → registers with MMA access pattern. Why mma\_access\_layout differs from the global→LDS thread layout:

| Layout | Purpose | Constraint |
| --- | --- | --- |
| load\_thread | Global → LDS write | Coalesced global reads |
| mma\_access | LDS → Registers read | AMD WMMA hardware pattern |

mma\_access\_layout encodes how AMD's WMMA instruction expects data: * Lane decomposition: (lane % 16, lane // 16) = (col\_group, row\_group) * Offset computation: col\_group \* 32 + row\_group \* 8 Using RuntimeLayout ensures compile-time evaluation (no GPU heap alloc). Layout compatibility requirements: * mma\_access\_layout must map exactly WARP\_SIZE (64) threads * smem must have enough elements for: num\_iterations \* WARP\_SIZE \* frag\_width * frag must store: num\_iterations \* frag\_width elements
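The lane decomposition and offset computation are plain index arithmetic. A Python sketch of the mapping (illustrative only; the kernel expresses this as a compile-time `Layout`, and `frag_width = 8` is assumed for the final coverage check):

```python
WARP_SIZE = 64  # AMD CDNA wavefront size

def mma_lane_offset(lane: int) -> int:
    # Lane decomposition: (lane % 16, lane // 16) = (col_group, row_group)
    col_group = lane % 16
    row_group = lane // 16
    # Offset computation: col_group * 32 + row_group * 8
    return col_group * 32 + row_group * 8

# Each lane reads a distinct 8-element slot; together the 64 lanes cover
# one 64 * 8 = 512-element region of LDS with no overlap.
offsets = sorted(mma_lane_offset(lane) for lane in range(WARP_SIZE))
```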
--- ## ping_pong_matmul
`ping_pong_matmul[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, //, enable_l2_cache_optimization: Bool = False, enable_swizzle: Bool = True, use_transpose_load: Bool = False](a_device_tensor: LayoutTensor[a_type, a_layout, origin], b_device_tensor: LayoutTensor[b_type, b_layout, origin], c_device_tensor: LayoutTensor[c_type, c_layout, origin], ctx: DeviceContext)`
--- ## ConsumerTile
`@register_passable(trivial)` `struct ConsumerTile[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_computed_per_consumer: Int]` Context manager for consumer access to a single ring buffer tile. ## Fields * ​consumer\_view\_ptr (`ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerViewPtrType`): * ​stage (`Int`): * ​consumer\_iteration (`Int`): * ​warp\_tile\_idx (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ConsumerViewPtrType` `comptime ConsumerViewPtrType = Pointer[ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerViewType, origin]` ### `ConsumerViewType` `comptime ConsumerViewType = ConsumerView[origin, ring_buffer_type, warps_computed_per_consumer]` ## Methods ### `__init__` `__init__(consumer_view_ptr: Pointer[ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerViewType, origin], stage: Int, consumer_iteration: Int, warp_tile_idx: Int) -> Self` ### `__enter__` `__enter__(mut self) -> ring_buffer_type.WarpTileTupleType` Acquire the tile for use. 
**Returns:** `ring_buffer_type.WarpTileTupleType` ### `__exit__` `__exit__(mut self)` Release the tile back to producers.
--- ## ConsumerView
`@register_passable(trivial)` `struct ConsumerView[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_computed_per_consumer: Int]` Consumer view of the unified ring buffer. ## Fields * ​ring\_buffer\_ptr (`ConsumerView[origin, ring_buffer_type, warps_computed_per_consumer].RingBufferPtrType`): * ​phases (`StaticTuple[Int32, (pipeline_stages * warps_computed_per_consumer)]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ConsumerTileType` `comptime ConsumerTileType = ConsumerTile[origin, ring_buffer_type, warps_computed_per_consumer]` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` Context manager entry. ### `__exit__` `__exit__(mut self)` Context manager exit. ### `acquire_tiles` `acquire_tiles(mut self, stage: Int, consumer_iteration: Int, warp_tile_idx: Int) -> ring_buffer_type.WarpTileTupleType` Acquire tiles for reading by this consumer. **Args:** * ​stage ([`Int`](/mojo/stdlib/builtin/int/Int)): Pipeline stage to read from. 
* ​consumer\_iteration ([`Int`](/mojo/stdlib/builtin/int/Int)): Which iteration this consumer is on (0 to warps\_computed\_per\_consumer-1). * ​warp\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Which tile this consumer wants to read. **Returns:** `ring_buffer_type.WarpTileTupleType` ### `release_tiles` `release_tiles(mut self, stage: Int, warp_tile_idx: Int)` Signal to producers that tile is free. ### `get_tile` `get_tile(mut self, stage: Int, consumer_iteration: Int, warp_tile_idx: Int) -> ConsumerView[origin, ring_buffer_type, warps_computed_per_consumer].ConsumerTileType` Get a context manager for accessing a tile. **Args:** * ​stage ([`Int`](/mojo/stdlib/builtin/int/Int)): Pipeline stage. * ​consumer\_iteration ([`Int`](/mojo/stdlib/builtin/int/Int)): Current iteration of this consumer. * ​warp\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Which tile to access. **Returns:** `ConsumerTile`: A context manager that acquires the tile on entry and releases it on exit.
--- ## ProducerTile
`@register_passable(trivial)` `struct ProducerTile[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_processed_per_producer: Int]` Context manager for producer access to a single ring buffer tile. ## Fields * ​producer\_view\_ptr (`ProducerTile[origin, ring_buffer_type, warps_processed_per_producer].ProducerViewPtrType`): * ​stage (`Int`): * ​producer\_iteration (`Int`): * ​warp\_tile\_idx (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ProducerViewPtrType` `comptime ProducerViewPtrType = Pointer[ProducerTile[origin, ring_buffer_type, warps_processed_per_producer].ProducerViewType, origin]` ### `ProducerViewType` `comptime ProducerViewType = ProducerView[origin, ring_buffer_type, warps_processed_per_producer]` ## Methods ### `__init__` `__init__(producer_view_ptr: Pointer[ProducerTile[origin, ring_buffer_type, warps_processed_per_producer].ProducerViewType, origin], stage: Int, producer_iteration: Int, warp_tile_idx: Int) -> Self` ### `__enter__` `__enter__(mut self) -> ring_buffer_type.WarpTileTupleType` Acquire the tile for use. 
**Returns:** `ring_buffer_type.WarpTileTupleType` ### `__exit__` `__exit__(mut self)` Release the tile back to consumers.
--- ## ProducerView
`@register_passable(trivial)` `struct ProducerView[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type]], warps_processed_per_producer: Int]` Producer view of the unified ring buffer. ## Fields * ​ring\_buffer\_ptr (`ProducerView[origin, ring_buffer_type, warps_processed_per_producer].RingBufferPtrType`): * ​phases (`StaticTuple[Int32, (pipeline_stages * warps_processed_per_producer)]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ProducerTileType` `comptime ProducerTileType = ProducerTile[origin, ring_buffer_type, warps_processed_per_producer]` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` Context manager entry. ### `__exit__` `__exit__(mut self)` Context manager exit. ### `acquire_tiles` `acquire_tiles(mut self, stage: Int, producer_iteration: Int, warp_tile_idx: Int) -> ring_buffer_type.WarpTileTupleType` Acquire tiles for writing by this producer. **Args:** * ​stage ([`Int`](/mojo/stdlib/builtin/int/Int)): Pipeline stage to write to. 
* ​producer\_iteration ([`Int`](/mojo/stdlib/builtin/int/Int)): Which iteration this producer is on (`0` to `warps_processed_per_producer - 1`). * ​warp\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Which tile this producer is responsible for. **Returns:** `ring_buffer_type.WarpTileTupleType` ### `release_tiles` `release_tiles(mut self, stage: Int, warp_tile_idx: Int)` Signal to consumers that tile is ready. ### `get_tile` `get_tile(mut self, stage: Int, warp_tile_idx: Int, producer_iteration: Int) -> ProducerView[origin, ring_buffer_type, warps_processed_per_producer].ProducerTileType` Get a context manager for accessing a tile. **Args:** * ​stage ([`Int`](/mojo/stdlib/builtin/int/Int)): Pipeline stage. * ​warp\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Which tile to access. * ​producer\_iteration ([`Int`](/mojo/stdlib/builtin/int/Int)): Current iteration of this producer. **Returns:** `ProducerTile`: A context manager that acquires the tile on entry and releases it on exit.
--- ## RingBuffer
`struct RingBuffer[dtype: DType, layout: Layout, pipeline_stages: Int, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, reads_per_warp_block: Int, tile_buffers: Int, sync_strategy_type: SyncStrategy]` Ring buffer for coordinating producer-consumer warps in matrix multiplication. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of elements. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Memory layout for shared memory tiles. * ​pipeline\_stages ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of stages for software pipelining. * ​block\_rows ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in block-level tiles. * ​block\_cols ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in block-level tiles. * ​warp\_rows ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in warp-level tiles. * ​warp\_cols ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in warp-level tiles. * ​reads\_per\_warp\_block ([`Int`](/mojo/stdlib/builtin/int/Int)): How many consumer warps read each tile. * ​tile\_buffers ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of separate tile buffers (usually 1). * ​sync\_strategy\_type ([`SyncStrategy`](/mojo/kernels/linalg/matmul/gpu/amd/ring_buffer_traits/SyncStrategy)): Synchronization strategy (SingleCounterSync or SplitCounterSync). 
## Fields * ​smem\_buffers (`RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SMemBuffersType`): * ​sync\_strategy (`sync_strategy_type`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = sync_strategy_type.__del__is_trivial` ### `block_warps` `comptime block_warps = (block_rows // warp_rows)` ### `SMemBuffersType` `comptime SMemBuffersType = StaticTuple[SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols], tile_buffers]` ### `SmemBufferType` `comptime SmemBufferType = SMemBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols]` ### `total_tiles` `comptime total_tiles = (RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].block_warps * pipeline_stages)` ### `WarpTileTupleType` `comptime WarpTileTupleType = StaticTuple[LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, block_rows, block_cols]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols](), 128, warp_rows, warp_cols]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, 
layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols]() if _tile_is_masked[pipeline_layout[layout, pipeline_stages](), block_rows, block_cols]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, block_rows, block_cols]()[0], warp_rows, warp_cols](), alignment=128], tile_buffers]` ### `WarpTileType` `comptime WarpTileType = RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].SmemBufferType.WarpTileType` ## Methods ### `__init__` `__init__(out self)` ### `get_tiles` `get_tiles(self, stage: Int, warp_tile_idx: Int) -> RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type].WarpTileTupleType` Get tiles from shared memory. **Returns:** `RingBuffer` ### `producer` `producer[warps_processed_per_producer: Int](mut self) -> ProducerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_processed_per_producer]` Create a producer view of this ring buffer. 
**Returns:** `ProducerView` ### `consumer` `consumer[warps_computed_per_consumer: Int](mut self) -> ConsumerView[self, RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warps_computed_per_consumer]` Create a consumer view of this ring buffer. **Returns:** `ConsumerView` ### `get_staged_idx` `get_staged_idx(self, tile_idx: Int, stage: Int) -> Int` Get the staged index for a tile and stage. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `wait_producer_acquire` `wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Producer waits to acquire a tile. ### `signal_producer_release` `signal_producer_release(mut self, tile_idx: Int, stage: Int)` Producer signals it has released a tile. ### `wait_consumer_acquire` `wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Consumer waits to acquire a tile. ### `signal_consumer_release` `signal_consumer_release(mut self, tile_idx: Int, stage: Int)` Consumer signals it has released a tile. ### `get_producer_phase_increment` `get_producer_phase_increment(self) -> Int32` Get the phase increment for producers. **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self) -> Int32` Get the phase increment for consumers. **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32)
--- ## ring_buffer
Ring Buffer implementation for producer-consumer synchronization in GPU kernels. This ring buffer coordinates data transfer between producer warps (loading from global memory) and consumer warps (performing computation) through shared memory tiles. Key features: * Configurable synchronization strategies via the SyncStrategy trait * Pipeline stages for overlapping data transfer and computation * Context managers for automatic acquire/release of tiles * Phase-based synchronization to prevent data races ## Structs * [​`ConsumerTile`](./ConsumerTile): Context manager for consumer access to a single ring buffer tile. * [​`ConsumerView`](./ConsumerView): Consumer view of the unified ring buffer. * [​`ProducerTile`](./ProducerTile): Context manager for producer access to a single ring buffer tile. * [​`ProducerView`](./ProducerView): Producer view of the unified ring buffer. * [​`RingBuffer`](./RingBuffer): Ring buffer for coordinating producer-consumer warps in matrix multiplication.
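The "context managers for automatic acquire/release" feature can be pictured with a small Python analogue (a sketch of the design pattern only; the class and method names here are hypothetical, not the Mojo API). Acquiring on entry and releasing on exit makes it impossible to forget the release that unblocks the other side:

```python
class TileHandle:
    """Analogue of ProducerTile/ConsumerTile: acquire on enter, release on exit."""
    def __init__(self, ring, stage):
        self.ring, self.stage = ring, stage
    def __enter__(self):
        self.ring.log.append(f"acquire {self.stage}")  # wait_*_acquire in the real kernel
        return self.ring.tiles[self.stage]
    def __exit__(self, *exc):
        self.ring.log.append(f"release {self.stage}")  # signal_*_release in the real kernel

class ToyRingBuffer:
    """Stand-in for RingBuffer with one tile per pipeline stage."""
    def __init__(self, pipeline_stages):
        self.tiles = [[0] * 4 for _ in range(pipeline_stages)]
        self.log = []
    def get_tile(self, stage):
        return TileHandle(self, stage)

ring = ToyRingBuffer(pipeline_stages=2)
with ring.get_tile(0) as tile:
    tile[0] = 42  # produce into the tile; the release runs automatically on exit
```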
--- ## SingleCounterSync
`@register_passable(trivial)` `struct SingleCounterSync[pipeline_stages: Int, block_rows: Int, warp_rows: Int, reads_per_warp_block: Int]` Single counter synchronization strategy. Uses one atomic counter per tile that tracks both producer and consumer progress. This is simpler but has higher contention as all warps compete for the same counter. Phase progression: * Each phase advances by (writes\_per\_warp\_block + reads\_per\_warp\_block) * Producers wait for phase N, increment counter by 1 * Consumers wait for phase N+1, increment counter by 1 ## Fields * ​sync\_counter (`SingleCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].SyncCounterArray`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`SyncStrategy`](/mojo/kernels/linalg/matmul/gpu/amd/ring_buffer_traits/SyncStrategy), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `block_warps` `comptime block_warps = (block_rows // warp_rows)` ### `SyncCounterArray` `comptime SyncCounterArray = SMemArrayType[Int32, SingleCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].total_tiles]` ### `total_tiles` `comptime total_tiles = (SingleCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].block_warps * pipeline_stages)` ### `writes_per_warp_block` `comptime writes_per_warp_block = 1` ## Methods ### `__init__` `__init__() -> Self` Initialize with internally allocated sync counter. 
### `get_staged_idx` `get_staged_idx(self, tile_idx: Int, stage: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `wait_producer_acquire` `wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` ### `signal_producer_release` `signal_producer_release(mut self, tile_idx: Int, stage: Int)` ### `wait_consumer_acquire` `wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` ### `signal_consumer_release` `signal_consumer_release(mut self, tile_idx: Int, stage: Int)` ### `get_producer_phase_increment` `get_producer_phase_increment(self) -> Int32` **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self) -> Int32` **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32)
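The phase progression above can be traced single-threaded. A Python sketch of the counting scheme for one tile (an illustration, not the Mojo implementation; `writes_per_warp_block = 1` as in the struct above, and the `assert`s stand in for the spin-waits):

```python
def single_counter_trace(iterations: int, reads_per_warp_block: int):
    # One shared counter; every producer write and consumer read bumps it by 1.
    counter, log = 0, []
    phase_inc = 1 + reads_per_warp_block   # writes_per_warp_block + reads_per_warp_block
    prod_phase, cons_phase = 0, 1          # consumers trail the write by one increment
    for _ in range(iterations):
        assert counter >= prod_phase       # producer acquire: prior readers are done
        log.append("W"); counter += 1      # producer release
        assert counter >= cons_phase       # consumer acquire: tile has been filled
        for _ in range(reads_per_warp_block):
            log.append("R"); counter += 1  # each consumer release
        prod_phase += phase_inc            # both sides advance by the same amount
        cons_phase += phase_inc
    return log
```

Because producers and consumers both bump the same counter, every warp contends on one atomic; that contention is what SplitCounterSync below avoids.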
--- ## SplitCounterSync
`@register_passable(trivial)` `struct SplitCounterSync[pipeline_stages: Int, block_rows: Int, warp_rows: Int, reads_per_warp_block: Int]` Split counter synchronization strategy. Uses separate producer and consumer counters per tile to reduce atomic contention. Producers only write to producer counters, consumers only write to consumer counters. Phase progression: * Producer phase advances by reads\_per\_warp\_block (waits for N consumers) * Consumer phase advances by writes\_per\_warp\_block (waits for 1 producer) * This asymmetry reflects the 1-producer-to-N-consumers relationship ## Fields * ​producer\_counters (`SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].ProducerCounterArray`): * ​consumer\_counters (`SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].ConsumerCounterArray`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`SyncStrategy`](/mojo/kernels/linalg/matmul/gpu/amd/ring_buffer_traits/SyncStrategy), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `block_warps` `comptime block_warps = (block_rows // warp_rows)` ### `ConsumerCounterArray` `comptime ConsumerCounterArray = SMemArrayType[Int32, SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].total_tiles]` ### `ProducerCounterArray` `comptime ProducerCounterArray = SMemArrayType[Int32, SplitCounterSync[pipeline_stages, block_rows, warp_rows, reads_per_warp_block].total_tiles]` ### `total_tiles` `comptime total_tiles = (SplitCounterSync[pipeline_stages, block_rows, 
warp_rows, reads_per_warp_block].block_warps * pipeline_stages)` ### `writes_per_warp_block` `comptime writes_per_warp_block = 1` ## Methods ### `__init__` `__init__() -> Self` Initialize with internally allocated producer and consumer counters. ### `get_staged_idx` `get_staged_idx(self, tile_idx: Int, stage: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `wait_producer_acquire` `wait_producer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Producer waits on consumer counter. ### `signal_producer_release` `signal_producer_release(mut self, tile_idx: Int, stage: Int)` Producer increments producer counter. ### `wait_consumer_acquire` `wait_consumer_acquire(self, tile_idx: Int, stage: Int, phase: Int32)` Consumer waits on producer counter. ### `signal_consumer_release` `signal_consumer_release(mut self, tile_idx: Int, stage: Int)` Consumer increments consumer counter by 1. ### `get_producer_phase_increment` `get_producer_phase_increment(self) -> Int32` Producer phase advances by reads\_per\_warp\_block. **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self) -> Int32` Consumer phase advances by writes\_per\_warp\_block. **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32)
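The asymmetric handshake can be sketched as a serial Python model (an illustration of the semantics described above, not the shared-memory Mojo implementation):

```python
# Serial model of SplitCounterSync for one tile slot: one producer,
# N consumers, and two counters that are never written by the same
# side (which is the point of the split design).

class SplitCounters:
    def __init__(self, num_consumers):
        self.num_consumers = num_consumers  # reads_per_warp_block
        self.producer_counter = 0           # written by the producer only
        self.consumer_counter = 0           # written by consumers only

    def producer_can_acquire(self, phase):
        # Producer refills once all N consumers of the previous round
        # have released: the consumer counter has reached the phase.
        return self.consumer_counter >= phase

    def signal_producer_release(self):
        self.producer_counter += 1

    def consumer_can_acquire(self, phase):
        # Each consumer waits only for the single producer fill.
        return self.producer_counter >= phase

    def signal_consumer_release(self):
        self.consumer_counter += 1

    def producer_phase_increment(self):
        return self.num_consumers  # waits for N consumer releases

    def consumer_phase_increment(self):
        return 1                   # waits for 1 producer release
```

One fill followed by N consumer releases advances the producer's phase by N and each consumer's phase by 1, matching the 1-producer-to-N-consumers relationship.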
--- ## SyncStrategy
Interface for synchronization strategies between producers and consumers. All methods have the same signature regardless of the specific implementation, allowing the RingBuffer to be parameterized with any conforming strategy. Phase tracking ensures producers and consumers access different tiles: * Producers wait until consumers have finished with a tile (phase N) * Consumers wait until producers have filled a tile (phase N+1) ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__init__` `__init__() -> _Self` Initialize with internally allocated sync counter. **Returns:** `_Self` ### `get_staged_idx` `get_staged_idx(self: _Self, tile_idx: Int, stage: Int) -> Int` Convert tile index and stage to a flat index in the counter arrays. **Args:** * ​tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Index of the tile within a stage (0 to block\_warps-1). * ​stage ([`Int`](/mojo/stdlib/builtin/int/Int)): Pipeline stage (0 to pipeline\_stages-1). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Flat index for accessing synchronization counters. ### `wait_producer_acquire` `wait_producer_acquire(self: _Self, tile_idx: Int, stage: Int, phase: Int32)` Producer waits until it can write to the specified tile. Blocks until all consumers have finished reading from this tile (counter >= phase). 
### `signal_producer_release` `signal_producer_release(mut self: _Self, tile_idx: Int, stage: Int)` Producer signals that it has finished writing to the tile. Increments the appropriate counter to notify waiting consumers. ### `wait_consumer_acquire` `wait_consumer_acquire(self: _Self, tile_idx: Int, stage: Int, phase: Int32)` Consumer waits until it can read from the specified tile. Blocks until producer has finished writing to this tile (counter >= phase). ### `signal_consumer_release` `signal_consumer_release(mut self: _Self, tile_idx: Int, stage: Int)` Consumer signals that it has finished reading from the tile. Increments the appropriate counter to notify waiting producers. ### `get_producer_phase_increment` `get_producer_phase_increment(self: _Self) -> Int32` Returns how much to advance the producer phase after each acquisition. This determines when producers can reuse a tile after consumers finish. **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32) ### `get_consumer_phase_increment` `get_consumer_phase_increment(self: _Self) -> Int32` Returns how much to advance the consumer phase after each acquisition. This determines when consumers can read a tile after producers finish. **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32)
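The phase rules above (producers wait for phase N, consumers for phase N+1) can be exercised with a serial Python model of a ring of pipeline slots; the slot layout and round-robin scheduling are illustrative assumptions:

```python
# Serial model of the ring-buffer phase protocol: a producer streams
# items through `stages` slots while a consumer drains them in order.
# `filled`/`drained` play the roles of the release counters; the
# acquire checks implement the phase-N / phase-N+1 rules.

def simulate_ring(work, stages):
    slots = [None] * stages
    filled = [0] * stages   # producer releases per slot
    drained = [0] * stages  # consumer releases per slot
    out = []
    p = c = 0
    while c < len(work):
        ps, cs = p % stages, c % stages
        # Producer acquire at phase N: every previous fill of this
        # slot has been drained, so it is safe to overwrite.
        if p < len(work) and drained[ps] >= filled[ps]:
            slots[ps] = work[p]
            filled[ps] += 1     # producer release
            p += 1
        # Consumer acquire at phase N+1: the slot holds one more fill
        # than has been drained, i.e. fresh data is waiting.
        elif filled[cs] >= drained[cs] + 1:
            out.append(slots[cs])
            drained[cs] += 1    # consumer release
            c += 1
    return out
```

The producer never runs more than `stages` items ahead of the consumer, which is exactly the overwrite protection the phases provide.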
--- ## increment_counter_if_first_thread
`increment_counter_if_first_thread(counter: UnsafePointer[Int32, origin, address_space=AddressSpace.SHARED], increment: Int32)` Atomically increment the counter, but only from the first thread in the warp.
--- ## ring_buffer_traits
Trait definitions and utilities for ring buffer synchronization strategies. This module provides: * SyncStrategy trait: Interface for producer-consumer synchronization protocols * SingleCounterSync: Uses a single atomic counter per tile (original RingBuffer behavior) * SplitCounterSync: Uses separate producer/consumer counters to reduce contention * Atomic utility functions for thread-safe counter operations ## Structs * [​`SingleCounterSync`](./SingleCounterSync): Single counter synchronization strategy. * [​`SplitCounterSync`](./SplitCounterSync): Split counter synchronization strategy. ## Traits * [​`SyncStrategy`](./SyncStrategy): Interface for synchronization strategies between producers and consumers. ## Functions * [​`increment_counter_if_first_thread`](./increment_counter_if_first_thread): Atomically increment the counter, but only from the first thread in the warp. * [​`wait_for_counter`](./wait_for_counter): Spin-wait until the counter reaches the threshold.
--- ## wait_for_counter
`wait_for_counter(counter: UnsafePointer[Int32, origin, address_space=AddressSpace.SHARED], threshold: Int32)` Spin-wait until the counter reaches the threshold.
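A Python analogue of the spin-wait (a model only: the real kernel spins on a shared-memory `Int32`, here a lock-protected integer and a short sleep stand in):

```python
# Model of wait_for_counter: spin until a shared counter reaches a
# threshold while other threads increment it.
import threading
import time

class SharedCounter:
    def __init__(self):
        self._value = 0
        self._lock = threading.Lock()

    def increment(self, by=1):
        with self._lock:
            self._value += by

    def value(self):
        with self._lock:
            return self._value

def wait_for_counter(counter, threshold):
    # Spin-wait; the sleep keeps the model polite on a CPU.
    while counter.value() < threshold:
        time.sleep(0.0001)

counter = SharedCounter()
workers = [threading.Thread(target=counter.increment) for _ in range(4)]
for w in workers:
    w.start()
wait_for_counter(counter, 4)  # unblocks once all four increments land
for w in workers:
    w.join()
```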
--- ## AMDSharedMemoryBarrier
`@register_passable(trivial)` `struct AMDSharedMemoryBarrier` A shared-memory barrier backed by a single `Int32` counter: warps `increment` it and spin in `wait_until_greater_or_equal_to` until it reaches a target value. ## Fields * ​\_\_repr (`Int32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `initialize` `initialize(ref [MutAnyOrigin, 3] self)` ### `value` `value(ref [3] self) -> Int32` **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32) ### `increment` `increment(ref [MutAnyOrigin, 3] self, warp_id: Int)` ### `wait_until_greater_or_equal_to` `wait_until_greater_or_equal_to(ref [3] self, v: Int32)`
--- ## AMDWarpSharedMemoryBarrier
`@register_passable(trivial)` `struct AMDWarpSharedMemoryBarrier[size: Int]` Per-warp variant of the shared-memory barrier, holding `size` `Int32` counters (one slot per warp). ## Fields * ​\_\_repr (`StaticTuple[Int32, size]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `initialize` `initialize(ref [MutAnyOrigin, 3] self)` ### `value` `value(ref [3] self) -> Int32` **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32) ### `increment` `increment(ref [MutAnyOrigin, 3] self, warp_id: Int)` ### `wait_until_greater_or_equal_to` `wait_until_greater_or_equal_to(ref [3] self, v: Int32)`
--- ## AmdTileOperator
`@register_passable(trivial)` `struct AmdTileOperator[InType: DType, OutType: DType, warp_block_layout_a: Layout, warp_block_layout_b: Layout, mma_shape: IndexList[3], swizzle: OptionalReg[Swizzle] = None, transpose_b: Bool = True]` Manages tensor core operations for matrix multiplication on AMD GPUs. This operator handles loading matrix fragments from shared memory to registers and performing matrix multiply-accumulate operations using tensor cores. Requirements: * warp\_block\_layout\_a.shape\[0] must be divisible by mma\_shape\[0] * warp\_block\_layout\_b.shape\[0] must be divisible by mma\_shape\[1] * warp\_block\_layout\_a.shape\[1] must be divisible by mma\_shape\[2] * warp\_block\_layout\_b.shape\[1] must be divisible by mma\_shape\[2] * The K dimension must align such that num\_k\_tiles is divisible by k\_group\_size ## Parameters * ​InType ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Input data type. * ​OutType ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Output data type. * ​warp\_block\_layout\_a ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout for matrix A warp tiles. * ​warp\_block\_layout\_b ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout for matrix B warp tiles. * ​mma\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Shape of the MMA operation \[M, N, K]. * ​swizzle ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional swizzle pattern for memory access. * ​transpose\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether matrix B is transposed.
## Fields * ​out\_reg\_tile (`AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].OutRegTileType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ARegTileType` `comptime ARegTileType = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `BRegTileType` `comptime BRegTileType = LayoutTensor[InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `k_group_size_a` `comptime k_group_size_a = 
(AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]())` ### `k_group_size_b` `comptime k_group_size_b = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width // num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]())` ### `k_tile_fragment_index` `comptime k_tile_fragment_index[k_tile_idx: Int] = (k_tile_idx % AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)` #### Parameters * ​k\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): ### `k_tile_group_index` `comptime k_tile_group_index[k_tile_idx: Int] = (k_tile_idx // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a)` #### Parameters * ​k\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): ### `num_k_tiles` `comptime num_k_tiles = (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].WK // mma_shape.__getitem__[3, DType.int64, Int](2))` ### `num_m_mmas` `comptime num_m_mmas = (product(warp_block_layout_a.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](0))` ### `num_n_mmas` `comptime num_n_mmas = (product(warp_block_layout_b.shape[0]) // mma_shape.__getitem__[3, DType.int64, Int](1))` ### `out_frag_size` `comptime out_frag_size = ((mma_shape.__getitem__[3, DType.int64, Int](0) * mma_shape.__getitem__[3, DType.int64, Int](1)) // WARP_SIZE)` ### `OutRegTileFragmentType` `comptime OutRegTileFragmentType = LayoutTensor[OutType, LayoutTensor._compute_tile_layout[True, OutType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * 
AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), _get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), False, align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]](), (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), 
num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), (AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()](), alignment=align_of[SIMD[InType, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]` ### `OutRegTileType` `comptime OutRegTileType = LayoutTensor[OutType, Layout.row_major((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](1)]()), MutAnyOrigin, address_space=AddressSpace.LOCAL, alignment=align_of[SIMD[InType, 
AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]]()]` ### `simd_width` `comptime simd_width = simd_width_of[InType]()` ### `tensor_core` `comptime tensor_core = TensorCore[OutType, InType, mma_shape, transpose_b]()` ### `total_k_tiles` `comptime total_k_tiles = AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles` ### `WK` `comptime WK = product(warp_block_layout_a.shape[1])` ## Methods ### `__init__` `__init__() -> Self` ### `a_reg_tile` `a_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[True, InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, 
OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), False, align_of[InType](), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // 
AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_a) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_m_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]` Get A register tile for a specific K tile. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `b_reg_tile` `b_reg_tile(self, k_tile_idx: Int) -> LayoutTensor[InType, LayoutTensor._compute_tile_layout[True, InType, Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), 
_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), False, align_of[InType](), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, layout_int_type=_get_layout_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), linear_idx_type=_get_index_type(Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AddressSpace.LOCAL), 
masked=_tile_is_masked[Layout.row_major(((AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_k_tiles // AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].k_group_size_b) * AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width), AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].num_n_mmas, AmdTileOperator[InType, OutType, warp_block_layout_a, warp_block_layout_b, mma_shape, swizzle, transpose_b].simd_width]()]` Get B register tile for a specific K tile. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `reset_accumulator` `reset_accumulator(self)` Reset the accumulator to zero for a new tile computation. ### `load_tile_fragment` `load_tile_fragment[k_tile_idx: Int](self, smem_tile_a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], smem_tile_b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Load fragments from shared memory to registers for a specific K tile. **Parameters:** * ​k\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): K-tile index (0 to total\_k\_tiles-1). **Args:** * ​smem\_tile\_a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Shared memory tile for matrix A. * ​smem\_tile\_b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Shared memory tile for matrix B. 
### `mma_compute` `mma_compute[k_tile_idx: Int](self)` Perform matrix multiply-accumulate for a specific K tile. This method assumes fragments are already loaded via load\_tile\_fragment. **Parameters:** * ​k\_tile\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): K-tile index (0 to total\_k\_tiles-1).
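The operator's k-tile bookkeeping can be modeled in Python. The `num_matrix_reg` formula `(m * k) // WARP_SIZE` is an assumption (an m x k fragment spread across one 64-lane AMD warp); the group/fragment split follows `k_tile_group_index` and `k_tile_fragment_index` above:

```python
# Model of AmdTileOperator's k-tile indexing: each of the
# num_k_tiles = WK // mma_K tiles is addressed by a group index and
# a fragment-within-group index.

WARP_SIZE = 64  # AMD wavefront size

def num_matrix_reg(m, k):
    # Assumed per-thread register count for an m x k fragment.
    return (m * k) // WARP_SIZE

def k_tile_indices(WK, mma_shape, simd_width):
    M, N, K = mma_shape
    num_k_tiles = WK // K
    k_group_size_a = simd_width // num_matrix_reg(M, K)
    # Stated requirement: num_k_tiles must divide evenly into groups.
    assert num_k_tiles % k_group_size_a == 0
    return [(k // k_group_size_a, k % k_group_size_a)
            for k in range(num_k_tiles)]
```

With WK = 64, a 16x16x16 MMA, and SIMD width 8, there are four k-tiles in two groups of two fragments each.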
--- ## Enum
Trait for enum-like value types. Conforming types supply an integer `value`; equality and identity comparisons are provided in terms of it. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `value` `value(self: _Self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ## Provided methods ### `__eq__` `__eq__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__is__` `__is__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__isnot__` `__isnot__(self: _Self, other: _Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## MMAConfig
`@register_passable(trivial)` `struct MMAConfig[InType: DType, OutType: DType, mma_shape: IndexList[3], transpose_b: Bool = True]` Compile-time configuration for a tensor core MMA operation: derives per-thread register counts and K-group sizes from `mma_shape` and the input type's SIMD width. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `k_group_size_a` `comptime k_group_size_a = (MMAConfig[InType, OutType, mma_shape, transpose_b].simd_width // MMAConfig[InType, OutType, mma_shape, transpose_b].registers_per_thread_a)` ### `k_group_size_b` `comptime k_group_size_b = (MMAConfig[InType, OutType, mma_shape, transpose_b].simd_width // MMAConfig[InType, OutType, mma_shape, transpose_b].registers_per_thread_b)` ### `mma` `comptime mma = TensorCore[OutType, InType, mma_shape, transpose_b]()` ### `registers_per_thread_a` `comptime registers_per_thread_a = num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](0), mma_shape.__getitem__[3, DType.int64, Int](2)]()` ### `registers_per_thread_b` `comptime registers_per_thread_b = num_matrix_reg[mma_shape.__getitem__[3, DType.int64, Int](1), mma_shape.__getitem__[3, DType.int64, Int](2)]()` ### `simd_width` `comptime simd_width = simd_width_of[InType]()` ## Methods ### `adjusted_mma_k_shape_a` `static adjusted_mma_k_shape_a() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `adjusted_mma_k_shape_b` `static adjusted_mma_k_shape_b() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## SMemBuffer
`@register_passable(trivial)` `struct SMemBuffer[dtype: DType, layout: Layout, pipeline_stages: Int, BM: Int, BN: Int, WM: Int, WN: Int]` Manages shared memory and returns 2D tile slices of the buffer. ## Fields * ​buffer (`SMemBuffer[dtype, layout, pipeline_stages, BM, BN, WM, WN].SMemTileType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BlockTileType` `comptime BlockTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, BM, BN]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN](), alignment=128]` ### `SMemTileType` `comptime SMemTileType = LayoutTensor[dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` ### `WarpTileType` `comptime WarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, 
AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, BM, BN]()[0], MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN](), 128, WM, WN]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), linear_idx_type=_get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), masked=_tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN]() if _tile_is_masked[pipeline_layout[layout, pipeline_stages](), BM, BN]() else _tile_is_masked[LayoutTensor._compute_tile_layout[True, dtype, pipeline_layout[layout, pipeline_stages](), MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), _get_index_type(pipeline_layout[layout, pipeline_stages](), AddressSpace.SHARED), False, 128, BM, BN]()[0], WM, WN](), alignment=128]` ## Methods ### `__init__` `__init__() -> Self` ### `get_tile` `get_tile(self, stage: Int) -> SMemBuffer[dtype, layout, pipeline_stages, BM, BN, WM, WN].BlockTileType` **Returns:** `BlockTileType`: The shared-memory block tile for pipeline stage `stage`.
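As a hedged usage sketch (the kernel context, the concrete `layout`, and the tile shapes below are illustrative assumptions, not part of this API):

```mojo
# Illustrative only: assumes a GPU kernel context where shared memory is
# available and `layout` describes one BM x BN block tile.
comptime dtype = DType.bfloat16
comptime stages = 2

fn kernel_body[layout: Layout]():
    # One buffer spanning all pipeline stages; 2D tiles are carved out
    # per stage so producers and consumers can overlap.
    var smem = SMemBuffer[dtype, layout, stages, 128, 128, 64, 64]()
    for stage in range(stages):
        # `get_tile` returns the BM x BN block tile for this stage.
        var tile = smem.get_tile(stage)
        # ... producers fill `tile`, consumers read it ...
```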
--- ## ThreadRole
`@register_passable(trivial)` `struct ThreadRole` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Enum`](/mojo/kernels/linalg/matmul/gpu/amd/structured/Enum), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CONSUMER` `comptime CONSUMER = ThreadRole(1)` ### `PRODUCER` `comptime PRODUCER = ThreadRole(0)` ### `PRODUCER_CONSUMER` `comptime PRODUCER_CONSUMER = ThreadRole(2)` ## Methods ### `value` `value(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `__str__` `__str__(self) -> String` Returns the string representation of this thread role. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String: A human-readable string representation of the thread role. ### `write_to` `write_to[W: Writer](self, mut writer: W)`
--- ## structured
## Structs * [​`AMDSharedMemoryBarrier`](./AMDSharedMemoryBarrier): * [​`AmdTileOperator`](./AmdTileOperator): Manages tensor core operations for matrix multiplication on AMD GPUs. * [​`AMDWarpSharedMemoryBarrier`](./AMDWarpSharedMemoryBarrier): * [​`MMAConfig`](./MMAConfig): * [​`SMemBuffer`](./SMemBuffer): Manages shared memory and returns 2D tile slices of the buffer. * [​`ThreadRole`](./ThreadRole): ## Traits * [​`Enum`](./Enum): ## Functions * [​`pipeline_layout`](./pipeline_layout):
--- ## pipeline_layout
`pipeline_layout[layout: Layout, pipeline_stages: Int]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## determine_thread_role
`determine_thread_role[producer_a_warps: Int, producer_b_warps: Int]() -> Tuple[ThreadRole, Int]` Returns (role, consumer\_warp\_id within role group). **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)
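The signature suggests a contiguous partition of warp ids into producer and consumer groups. A hedged sketch of that mapping (an assumption illustrating the likely partition, not the verified implementation):

```mojo
# Assumed mapping: warps [0, A) produce A tiles, warps [A, A+B) produce
# B tiles, and all remaining warps are consumers. The second tuple
# element is the warp's index within its own role group.
fn sketch_thread_role[producer_a_warps: Int, producer_b_warps: Int](
    warp_id: Int,
) -> Tuple[ThreadRole, Int]:
    if warp_id < producer_a_warps:
        return (ThreadRole.PRODUCER, warp_id)
    if warp_id < producer_a_warps + producer_b_warps:
        return (ThreadRole.PRODUCER, warp_id - producer_a_warps)
    return (
        ThreadRole.CONSUMER,
        warp_id - producer_a_warps - producer_b_warps,
    )
```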
--- ## get_producer_warp_thread_layout
`get_producer_warp_thread_layout[k_tile_size: Int, simd_width: Int, block_rows: Int, block_cols: Int]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## warp_spec_matmul
AMD Warp-Specialized Matrix Multiplication Architecture Overview: * Producer warps: Load tiles from global to shared memory * A producers: Load M×K tiles from matrix A * B producers: Load N×K tiles from matrix B * Consumer warps: Perform matrix multiplication using shared memory tiles * Ring buffer: Coordinates producer-consumer synchronization with barriers Data Flow: 1. Producers load tiles into shared memory stages 2. Barriers ensure data is ready before consumers access it 3. Consumers compute partial results and accumulate 4. Final results written back to global memory Memory Layout: * Shared memory is divided into pipeline stages for overlapping * Each stage contains block tiles that are further divided into warp tiles * Swizzling may be applied to avoid bank conflicts Ring Buffer Configuration: * Uses SingleCounterSync strategy by default (single atomic counter per tile) * Can be changed to SplitCounterSync in the RingBuffer type aliases for reduced contention * The trait-based design allows easy experimentation with different sync strategies ## `comptime` values ### `GlobalTensor` `comptime GlobalTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL]` #### Parameters * ​dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Functions * [​`determine_thread_role`](./determine_thread_role): Returns (role, consumer\_warp\_id within role group). * [​`get_producer_warp_thread_layout`](./get_producer_warp_thread_layout): * [​`lgkm_wait`](./lgkm_wait): * [​`run_producer`](./run_producer): Generic producer function for loading matrix tiles from global to shared memory. * [​`smem_tile_layout`](./smem_tile_layout): * [​`validate_config`](./validate_config): Validates the configuration parameters for the matrix multiplication kernel. 
* [​`warp_specialized_matmul`](./warp_specialized_matmul): * [​`warp_specialized_matmul_kernel`](./warp_specialized_matmul_kernel):
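The data flow described above can be sketched as paired producer/consumer loops over the pipeline stages. This is pseudocode-level Mojo; the barrier and load/MMA calls are placeholders for the actual ring-buffer API:

```mojo
# Producer side: cycle through shared-memory stages, waiting for
# consumers to release a stage before overwriting it.
fn producer_loop[stages: Int](num_k_tiles: Int):
    for k in range(num_k_tiles):
        var stage = k % stages
        # wait until consumers have drained `stage` (barrier, assumed)
        # load the next global tile into the stage's shared-memory slot
        # signal that `stage` is full

# Consumer side: mirror image, accumulating partial MMA results.
fn consumer_loop[stages: Int](num_k_tiles: Int):
    for k in range(num_k_tiles):
        var stage = k % stages
        # wait until `stage` is full
        # issue tensor-core MMAs on the stage's warp tiles, accumulate
        # signal that `stage` is empty
    # write accumulated results back to global memory
```

With more than one stage, the producer's load of tile `k + 1` overlaps the consumer's MMA on tile `k`, which is the point of the pipeline.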
--- ## lgkm_wait
`lgkm_wait()`
--- ## run_producer
`run_producer[dtype: DType, layout: Layout, block_rows: Int, block_cols: Int, warp_rows: Int, warp_cols: Int, producer_warps: Int, pipeline_stages: Int, k_tile_size: Int, simd_width: Int, warps_processed_per_producer: Int, tile_count: Int, swizzle: OptionalReg[Swizzle]](matrix: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL], mut ring_buffer: RingBuffer[dtype, layout, pipeline_stages, block_rows, block_cols, warp_rows, warp_cols, reads_per_warp_block, tile_buffers, sync_strategy_type], warp_id: UInt, block_idx_dim: Int)` Generic producer function for loading matrix tiles from global to shared memory.
--- ## smem_tile_layout
`smem_tile_layout[k_tile_size: Int, block_rows: Int, block_cols: Int]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## validate_config
`validate_config[BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, m_warps: Int, n_warps: Int, producer_a: Int, producer_b: Int, consumer: Int]()` Validates the configuration parameters for the matrix multiplication kernel.
--- ## warp_specialized_matmul
`warp_specialized_matmul[M: Int, N: Int, K: Int, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, a_producer_warps: Int, b_producer_warps: Int, consumer_warps: Int, pipeline_stages: Int = 1](a_device_tensor: LayoutTensor[DType.bfloat16, Layout.row_major(M, K), origin], b_device_tensor: LayoutTensor[DType.bfloat16, Layout.row_major(N, K), origin], c_device_tensor: LayoutTensor[DType.float32, Layout.row_major(M, N), origin], ctx: DeviceContext)`
--- ## warp_specialized_matmul_kernel
`warp_specialized_matmul_kernel[in_type: DType, out_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, a_producer_warps: Int, b_producer_warps: Int, consumer_warps: Int, pipeline_stages: Int](a: LayoutTensor[in_type, a_layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL], b: LayoutTensor[in_type, b_layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL], c: LayoutTensor[out_type, c_layout, MutAnyOrigin, address_space=AddressSpace.GLOBAL])`
--- ## gpu
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Packages * [​`amd`](./amd/): Provides the AMD GPU backend implementations for matmuls. * [​`sm100`](./sm100/): Provides the Nvidia Blackwell backend implementations for matmuls. * [​`sm100_structured`](./sm100_structured/): SM100 Structured Matmul - Refactored with encapsulated pipeline management. * [​`sm80`](./sm80/): Provides the Nvidia Ampere backend implementations for matmuls. * [​`sm90`](./sm90/): Provides the Nvidia Hopper backend implementations for matmuls. ## Modules * [​`profiler`](./profiler/): * [​`tile_scheduler`](./tile_scheduler/): * [​`tile_scheduler_splitk`](./tile_scheduler_splitk/): ## Functions * [​`matmul_kernel`](./matmul_kernel): Matrix Multiplication using shared memory. This version loads blocks of size tile\_size x tile\_size from A and B and updates a tile\_size x tile\_size tile in C. The thread block should have shape (tile\_size, tile\_size, 1). Each thread is mapped to one element in C. The grid should have shape (N/tile\_size, M/tile\_size, 1). N is the first dimension for coalesced access. * [​`matmul_kernel_naive`](./matmul_kernel_naive): * [​`multistage_gemm`](./multistage_gemm): * [​`split_k_reduce`](./split_k_reduce):
--- ## matmul_kernel
`matmul_kernel[c_type: DType, a_type: DType, b_type: DType, tile_size: Int, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type]()](c_ptr: LegacyUnsafePointer[Scalar[c_type]], a_ptr: LegacyUnsafePointer[Scalar[a_type]], b_ptr: LegacyUnsafePointer[Scalar[b_type]], m: Int, n: Int, k: Int)` Matrix Multiplication using shared memory. This version loads blocks of size tile\_size x tile\_size from A and B and updates a tile\_size x tile\_size tile in C. The thread block should have shape (tile\_size, tile\_size, 1). Each thread is mapped to one element in C. The grid should have shape (N/tile\_size, M/tile\_size, 1). N is the first dimension for coalesced access.
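A hedged launch sketch following the documented block and grid shapes. The float32 dtypes, the 16-element tile, and the ceiling-division of the grid dimensions are illustrative assumptions; `DeviceContext.enqueue_function` with `grid_dim`/`block_dim` is the standard Mojo launch path, but exact argument marshaling may differ:

```mojo
# Block shape is (tile_size, tile_size, 1); grid is (N/tile_size,
# M/tile_size, 1) with N first for coalesced access, as documented.
comptime tile_size = 16

fn launch_matmul(
    ctx: DeviceContext,
    c_ptr: LegacyUnsafePointer[Scalar[DType.float32]],
    a_ptr: LegacyUnsafePointer[Scalar[DType.float32]],
    b_ptr: LegacyUnsafePointer[Scalar[DType.float32]],
    m: Int, n: Int, k: Int,
) raises:
    ctx.enqueue_function[
        matmul_kernel[DType.float32, DType.float32, DType.float32, tile_size]
    ](
        c_ptr, a_ptr, b_ptr, m, n, k,
        # Round up so partial tiles at the matrix edges are still covered.
        grid_dim=((n + tile_size - 1) // tile_size,
                  (m + tile_size - 1) // tile_size, 1),
        block_dim=(tile_size, tile_size, 1),
    )
```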
--- ## matmul_kernel_naive
`matmul_kernel_naive[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, BLOCK_DIM: Int, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, s_type: DType = get_accum_type[c_type]()](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], m: Int, n: Int, k: Int)`
--- ## multistage_gemm
`multistage_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, origin, c_shape], a: NDBuffer[a_type, 2, origin, a_shape], b: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)` `multistage_gemm[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, origin, c_shape], a: NDBuffer[a_type, 2, origin, a_shape], b: NDBuffer[b_type, 2, origin, b_shape], runtime_config: MatmulConfig[a_type, b_type, c_type, transpose_b], ctx: DeviceContext)`
--- ## BlackwellProfileWarp
`struct BlackwellProfileWarp[load_warps: UInt32, mma_warps: UInt32, scheduler_warps: UInt32, epilogue_warps: UInt32, max_entries_per_warp: UInt32, //, WorkspaceManager: BlackwellWarpProfilingWorkspaceManager[load_warps, mma_warps, scheduler_warps, epilogue_warps, max_entries_per_warp], warp_role: UInt32 = 0]` This struct records the execution time of one or more warps and writes a single entry to the workspace. ## Fields * ​timeline (`Tuple[UInt64, UInt64]`): * ​workspace (`Span[UInt64, MutAnyOrigin]`): * ​entry\_idx (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `enable_profiling` `comptime enable_profiling = (max_entries_per_warp > 0)` ## Methods ### `__init__` `__init__(out self, workspace: Span[UInt64, MutAnyOrigin], entry_idx: UInt32)` ### `__enter__` `__enter__(mut self)` ### `__exit__` `__exit__(mut self)`
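Since the struct implements `__enter__`/`__exit__`, it is presumably used as a context manager around the region being timed. A hedged usage sketch via the `MatmulProfileWarp` alias from the `profiler` module; the `warp_role` of 1 and the 64 entries per warp are illustrative assumptions:

```mojo
# Assumed usage inside a kernel: time one unit of work for this warp's
# role and record it as entry `entry_idx` in the profiling workspace.
fn timed_region(workspace: Span[UInt64, MutAnyOrigin], entry_idx: UInt32):
    with MatmulProfileWarp[1, 64](workspace, entry_idx):
        # ... the work being profiled (e.g. one MMA pipeline iteration) ...
        pass
```

When `max_entries_per_warp` is 0, `enable_profiling` is `False`, so the scope presumably compiles down to a no-op.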
--- ## BlackwellWarpProfilingWorkspaceManager
`@register_passable(trivial)` `struct BlackwellWarpProfilingWorkspaceManager[load_warps: UInt32, mma_warps: UInt32, scheduler_warps: UInt32, epilogue_warps: UInt32, max_entries_per_warp: UInt32]` This struct manages the profiling workspace. The workspace consists of equal-sized chunks, the total number of which is equal to the total number of active SMs. Each SM chunk consists of sequences of entries, with a maximum number of entries per warp role. **Parameters:** * ​load\_warps: Number of warps specialized for load operations. * ​mma\_warps: Number of warps specialized for matrix multiply-accumulate operations. * ​scheduler\_warps: Number of warps specialized for scheduling operations. * ​epilogue\_warps: Number of warps specialized for epilogue operations. * ​max\_entries\_per\_warp: Maximum number of entries per warp (common across all warp roles). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `entries_per_sm` `comptime entries_per_sm = max_entries_per_warp.__rmul__[DType.uint32, 1](BlackwellWarpProfilingWorkspaceManager[load_warps, mma_warps, scheduler_warps, epilogue_warps, max_entries_per_warp].total_warp_roles)` ### `header` `comptime header = "time_start,time_end,sm_id,block_idx_x,block_idx_y,role,entry_idx\n"` ### `sm_count` `comptime sm_count = B200.sm_count` ### `total_data_points` `comptime total_data_points = 7` ### `total_warp_roles` `comptime total_warp_roles = 4` ## Methods ### `get_workspace` `static get_workspace(ctx: DeviceContext) -> Span[UInt64, MutAnyOrigin]` 
**Returns:** [`Span`](/mojo/stdlib/memory/span/Span) ### `write_to_workspace` `static write_to_workspace[warp_role: UInt32](sm_idx: UInt32, entry_idx: UInt32, workspace: Span[UInt64, MutAnyOrigin], timeline: Tuple[UInt64, UInt64])` ### `dump_workspace_as_csv` `static dump_workspace_as_csv(ctx: DeviceContext, workspace: Span[UInt64, MutAnyOrigin], filename: StringSlice[StaticConstantOrigin])`
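From the documented `comptime` members, the workspace size in `UInt64` elements can be worked out as follows. This is a hedged calculation sketch: the 4 roles and 7 data points per entry come from `total_warp_roles` and `total_data_points` above (the 7 columns match the CSV header), while the helper function itself is an illustration, not part of the API:

```mojo
# Per the documented members: 4 warp roles, 7 UInt64 data points per
# entry, and max_entries_per_warp entries per role per SM.
fn workspace_len[max_entries_per_warp: UInt32](sm_count: Int) -> Int:
    comptime total_warp_roles = 4
    comptime total_data_points = 7
    # entries_per_sm = max_entries_per_warp * total_warp_roles
    var entries_per_sm = Int(max_entries_per_warp) * total_warp_roles
    # One equal-sized chunk per active SM, each holding all roles' entries.
    return sm_count * entries_per_sm * total_data_points
```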
--- ## profiler
## `comptime` values ### `MatmulProfileWarp` `comptime MatmulProfileWarp[warp_role: UInt32, max_entries_per_warp: UInt32] = BlackwellProfileWarp[BlackwellWarpProfilingWorkspaceManager[1, 1, 1, 4, max_entries_per_warp](), warp_role]` #### Parameters * ​warp\_role ([`UInt32`](/stdlib/builtin/simd/#uint32)): * ​max\_entries\_per\_warp ([`UInt32`](/stdlib/builtin/simd/#uint32)): ### `MatmulWarpSpecializationWorkSpaceManager` `comptime MatmulWarpSpecializationWorkSpaceManager[max_entries_per_warp: UInt32] = BlackwellWarpProfilingWorkspaceManager[1, 1, 1, 4, max_entries_per_warp]` #### Parameters * ​max\_entries\_per\_warp ([`UInt32`](/stdlib/builtin/simd/#uint32)): ## Structs * [​`BlackwellProfileWarp`](./BlackwellProfileWarp): This struct records the execution time of one or more warps and writes a single entry to the workspace. * [​`BlackwellWarpProfilingWorkspaceManager`](./BlackwellWarpProfilingWorkspaceManager): This struct manages the profiling workspace. The workspace consists of equal-sized chunks, the total number of which is equal to the total number of active SMs. Each SM chunk consists of sequences of entries, with a maximum number of entries per warp role.
--- ## B200BlockScaledMatmulSmem
`struct B200BlockScaledMatmulSmem[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b]]` ## Fields * ​a\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].a_smem_size]`): * ​b\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].b_smem_size]`): * ​c\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].CType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].c_smem_size]`): * ​sfa\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].AScalesType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfa_smem_size]`): * ​sfb\_smem (`InlineArray[B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BScalesType, B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sfb_smem_size]`): * ​tma\_mma\_mbars (`InlineArray[SharedMemBarrier, (Int(B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].num_group_pipeline_stages) * 2)]`): * ​accum\_mbars (`InlineArray[SharedMemBarrier, (Int(config) * 2)]`): * ​clc\_mbars\_full (`InlineArray[SharedMemBarrier, Int(config)]`): * ​clc\_mbars\_empty (`InlineArray[SharedMemBarrier, Int(config)]`): * ​clc\_throttle\_mbars (`InlineArray[SharedMemBarrier, (Int(config) * 2)]`): * ​clc\_response 
(`InlineArray[UInt128, Int(config)]`): * ​tmem\_dealloc\_mbar (`InlineArray[SharedMemBarrier, 1]`): * ​tmem\_addr (`InlineArray[UInt32, 1]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_size` `comptime a_smem_size = ((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK) * Int(config))` ### `AScalesType` `comptime AScalesType = Scalar[sfa_dtype]` ### `AType` `comptime AType = Scalar[a_type]` ### `b_smem_size` `comptime b_smem_size = ((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BN * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BK) * Int(config))` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BScalesType` `comptime BScalesType = Scalar[sfb_dtype]` ### `BType` `comptime BType = Scalar[b_type]` ### `c_smem_size` `comptime c_smem_size = ((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputM * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].OutputN) * Int(config))` ### `CType` `comptime CType = Scalar[c_type]` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_group_pipeline_stages` 
`comptime num_group_pipeline_stages = (config // config)` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `sf_block_atom_size` `comptime sf_block_atom_size = (((load_from_mem SF_ATOM_M.__getitem__[Int, Int, 0]()) * (load_from_mem SF_ATOM_M.__getitem__[Int, Int, 1]())) * 4)` ### `sfa_smem_size` `comptime sfa_smem_size = (((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].BM // SF_MN_GROUP_SIZE) * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sf_block_atom_size) * Int(config))` ### `sfb_smem_size` `comptime sfb_smem_size = (((B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].MMA_N // SF_MN_GROUP_SIZE) * B200BlockScaledMatmulSmem[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b, config=config].sf_block_atom_size) * Int(config))`
--- ## blackwell_block_scaled_matmul_tma_umma_warp_specialized
`blackwell_block_scaled_matmul_tma_umma_warp_specialized[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, sfa_dtype: DType, sfa_layout: Layout, sfb_dtype: DType, sfb_layout: Layout, transpose_b: Bool, *, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: OptionalReg[UInt32] = None](c_device: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_device: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_device: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[sfa_dtype, sfa_layout, MutAnyOrigin], b_scales: LayoutTensor[sfb_dtype, sfb_layout, MutAnyOrigin], ctx: DeviceContext)`
--- ## blackwell_block_scaled_tma_umma_warp_specialized_kernel
`blackwell_block_scaled_tma_umma_warp_specialized_kernel[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, sfa_tile_layout: Layout, sfb_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, transpose_b: Bool, config: BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_tile_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_tile_layout, sfb_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])`
--- ## consumer_main_loop (Block_scaled_matmul)
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_smem_layout: Layout, b_smem_layout: Layout, sfa_smem_layout: Layout, sfb_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], SFA_NUM_COLS: Int, SFB_NUM_COLS: Int, cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), k_group_size: UInt = 1](tmem_addr: UInt32, sfa_tmem: UInt32, sfb_tmem: UInt32, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfa_smem_iter: LayoutTensorIter[sfa_dtype, sfa_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfb_smem_iter: LayoutTensorIter[sfb_dtype, sfb_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[pipeline_stages], mma_op: MmaOpSM100_BlockScaled_SS[c_type, a_type, b_type, sfa_dtype, sfb_dtype, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)`
--- ## copy_sf_tmem
`copy_sf_tmem[sf_dtype: DType, sf_smem_layout: Layout, TILE_MN: Int, cta_group: Int](sf_smem: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sf_tmem: UInt32)`
--- ## block_scaled_matmul (Block_scaled_matmul)
## Structs * [​`B200BlockScaledMatmulSmem`](./B200BlockScaledMatmulSmem): ## Functions * [​`blackwell_block_scaled_matmul_tma_umma_warp_specialized`](./blackwell_block_scaled_matmul_tma_umma_warp_specialized): * [​`blackwell_block_scaled_tma_umma_warp_specialized_kernel`](./blackwell_block_scaled_tma_umma_warp_specialized_kernel): * [​`consumer_main_loop`](./consumer_main_loop): * [​`copy_sf_tmem`](./copy_sf_tmem): * [​`load_AB`](./load_AB):
--- ## load_AB (Block_scaled_matmul)
`load_AB[a_type: DType, b_type: DType, sfa_dtype: DType, sfb_dtype: DType, a_layout: Layout, b_layout: Layout, sfa_layout: Layout, sfb_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, sfa_desc_layout: Layout, sfb_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, sfa_smem_layout: Layout, sfb_smem_layout: Layout, num_pipeline_stages: UInt, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, k_group_size: UInt = 1](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], sfa_tma_op: TMATensorTile[sfa_dtype, sfa_layout, sfa_desc_layout], sfb_tma_op: TMATensorTile[sfb_dtype, sfb_layout, sfb_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfa_smem: LayoutTensorIter[sfa_dtype, sfa_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], sfb_smem: LayoutTensorIter[sfb_dtype, sfb_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[Int(num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)`
--- ## blockwise_fp8
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `smem_layout_3D` `comptime smem_layout_3D[layout: Layout] = Layout(IntTuple(IntTuple(1), layout.shape[0].owned_copy(), layout.shape[1].owned_copy(), Tuple[]()), IntTuple(IntTuple(0), layout.stride[0].owned_copy(), layout.stride[1].owned_copy(), Tuple[]()))` #### Parameters * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Functions * [​`matmul_sm100_blockwise_scaled_fp8`](./matmul_sm100_blockwise_scaled_fp8): * [​`matmul_sm100_blockwise_scaled_fp8_1d2d_kernel`](./matmul_sm100_blockwise_scaled_fp8_1d2d_kernel): * [​`matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper`](./matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper):
--- ## matmul_sm100_blockwise_scaled_fp8
`matmul_sm100_blockwise_scaled_fp8[a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, c_type: DType, a_type: DType, b_type: DType, a_scales_type: DType, b_scales_type: DType, *, transpose_b: Bool, umma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## matmul_sm100_blockwise_scaled_fp8_1d2d_kernel
`matmul_sm100_blockwise_scaled_fp8_1d2d_kernel[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, a_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, a_scales_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: UInt = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], num_iters: UInt)`
--- ## matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper
`matmul_sm100_blockwise_scaled_fp8_1d2d_wrapper[a_type: DType, b_type: DType, c_type: DType, a_scales_type: DType, b_scales_type: DType, a_layout: Layout, c_layout: Layout, a_scales_layout: Layout, b_scales_layout: Layout, a_tile_layout: Layout, b_tile_layout: Layout, a_scales_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: UInt = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], num_iters: UInt)`
--- ## LoadOp
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `args_type` `comptime args_type` ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. ## Required methods ### `__init__` `__init__(out self: _Self, args: _Self.args_type)` **Returns:** `_Self` ### `__call__` `__call__(self: _Self, a_smem_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], m: UInt32, n: UInt32, k: UInt32, ref [3] mbar: SharedMemBarrier)` ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. 
For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name.
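The required methods above can be illustrated with a minimal (hypothetical) conformance. Everything here other than the trait's required members is an assumption, including the struct name, the placeholder args struct, and the elided tensor parameters:

```mojo
# Hypothetical skeleton of a LoadOp conformance; NoOpLoad and
# NoOpLoadArgs are illustrative names, not part of the module.
struct NoOpLoad(LoadOp):
    comptime args_type = NoOpLoadArgs  # assumed OpArgs-conforming struct
    comptime device_type = NoOpLoad

    fn __init__(out self, args: Self.args_type):
        pass

    fn __call__(
        self,
        a_smem_tile: LayoutTensor[...],  # shared-memory A tile (params elided)
        b_smem_tile: LayoutTensor[...],  # shared-memory B tile (params elided)
        m: UInt32, n: UInt32, k: UInt32,
        ref [3] mbar: SharedMemBarrier,
    ):
        # A real load op (e.g. TMALoadOp) would issue copies into the
        # shared-memory tiles and signal `mbar`; this skeleton does nothing.
        pass

    @staticmethod
    fn get_type_name() -> String:
        return "NoOpLoad"

    @staticmethod
    fn get_device_type_name() -> String:
        return "NoOpLoad"
```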
--- ## MmaOp (Composable)
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__call__` `__call__(self: _Self)`
--- ## OpArgs
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. 
**Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## OutputOp
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `args_type` `comptime args_type` ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. ## Required methods ### `__init__` `__init__(out self: _Self, args: _Self.args_type)` **Returns:** `_Self` ### `__call__` `__call__(self: _Self, tmem_addr: UInt32)` ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name.
--- ## Pipeline
`struct Pipeline[a_type: DType, b_type: DType, c_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, loadop_t: LoadOp, outputop_t: OutputOp]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`PipelineOp`](/mojo/kernels/linalg/matmul/gpu/sm100/composable/PipelineOp), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `args_type` `comptime args_type = PipelineArgs[loadop_t, outputop_t]` ## Methods ### `run` `static run(args: PipelineArgs[loadop_t, outputop_t])`
--- ## PipelineArgs
`struct PipelineArgs[loadop_t: LoadOp, outputop_t: OutputOp]` ## Fields * ​load\_args (`loadop_t.args_type`): * ​output\_args (`outputop_t.args_type`): * ​num\_iters (`UInt`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OpArgs`](/mojo/kernels/linalg/matmul/gpu/sm100/composable/OpArgs), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = loadop_t.args_type.__copyinit__is_trivial and outputop_t.args_type.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = loadop_t.args_type.__del__is_trivial and outputop_t.args_type.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = loadop_t.args_type.__moveinit__is_trivial and outputop_t.args_type.__moveinit__is_trivial` ### `device_type` `comptime device_type = PipelineArgs[loadop_t, outputop_t]` ## Methods ### `__init__` `__init__(out self, load_args: loadop_t.args_type, output_args: outputop_t.args_type, num_iters: UInt)` ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing 
arguments to kernels. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
--- ## PipelineOp
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `args_type` `comptime args_type` ## Required methods ### `run` `static run(args: _Self.args_type)`
--- ## R2GOutputOp
`struct R2GOutputOp[accum_type: DType, dtype: DType, layout: Layout, num_threads: Int, mma_shape: IndexList[3], block_tile_shape: IndexList[3], o: MutOrigin]` ## Fields * ​c (`LayoutTensor[dtype, layout, o]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OutputOp`](/mojo/kernels/linalg/matmul/gpu/sm100/composable/OutputOp), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `args_type` `comptime args_type = STOutputOpArgs[dtype, layout, o]` ### `device_type` `comptime device_type = R2GOutputOp[accum_type, dtype, layout, num_threads, mma_shape, block_tile_shape, o]` ## Methods ### `__init__` `__init__(out self, args: STOutputOpArgs[dtype, layout, o])` ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `to_kernel_args` `static to_kernel_args(c: LayoutTensor[dtype, layout, o], ctx: DeviceContext) -> R2GOutputOp[accum_type, dtype, layout, num_threads, mma_shape, block_tile_shape, o].args_type` **Returns:** `STOutputOpArgs[dtype, layout, o]`: The kernel-ready argument struct (this op's `args_type`). ### `__call__` `__call__(self, tmem_addr: UInt32)`
--- ## STOutputOpArgs
`struct STOutputOpArgs[mut: Bool, //, dtype: DType, layout: Layout, sb: Origin[mut]]` ## Fields * ​c (`LayoutTensor[dtype, layout, sb]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OpArgs`](/mojo/kernels/linalg/matmul/gpu/sm100/composable/OpArgs), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, c: LayoutTensor[dtype, layout, sb])` ### `__copyinit__` `__copyinit__(out self, other: Self)`
--- ## TMALoadOp
`struct TMALoadOp[a_type: DType, b_type: DType, block_tile_shape: IndexList[3], cluster_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B]` ## Fields * ​a\_tma\_ptr (`LegacyUnsafePointer[TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].a_tma_type]`): * ​b\_tma\_ptr (`LegacyUnsafePointer[TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].b_tma_type]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`LoadOp`](/mojo/kernels/linalg/matmul/gpu/sm100/composable/LoadOp), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_tma_desc_layout` `comptime a_tma_desc_layout = _tma_desc_tile_layout[a_type, 2, Index((block_tile_shape.__getitem__[3, DType.int64, Int](0) // cluster_shape.__getitem__[3, DType.int64, Int](0)), block_tile_shape.__getitem__[3, DType.int64, Int](2)), a_swizzle]()` ### `a_tma_layout` `comptime a_tma_layout = Layout.row_major((block_tile_shape.__getitem__[3, DType.int64, Int](0) // cluster_shape.__getitem__[3, DType.int64, Int](0)), block_tile_shape.__getitem__[3, DType.int64, Int](2))` ### `a_tma_type` `comptime a_tma_type = TMATensorTile[a_type, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].a_tma_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].a_tma_desc_layout]` ### `args_type` `comptime args_type = 
TMALoadOpArgs[a_type, b_type, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].a_tma_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].b_tma_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].a_tma_desc_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].b_tma_desc_layout]` ### `b_tma_desc_layout` `comptime b_tma_desc_layout = _tma_desc_tile_layout[b_type, 2, Index((block_tile_shape.__getitem__[3, DType.int64, Int](1) // cluster_shape.__getitem__[3, DType.int64, Int](1)), block_tile_shape.__getitem__[3, DType.int64, Int](2)), b_swizzle]()` ### `b_tma_layout` `comptime b_tma_layout = Layout.row_major((block_tile_shape.__getitem__[3, DType.int64, Int](1) // cluster_shape.__getitem__[3, DType.int64, Int](1)), block_tile_shape.__getitem__[3, DType.int64, Int](2))` ### `b_tma_type` `comptime b_tma_type = TMATensorTile[b_type, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].b_tma_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].b_tma_desc_layout]` ### `device_type` `comptime device_type = TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle]` ## Methods ### `__init__` `__init__(out self, args: TMALoadOpArgs[a_type, b_type, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].a_tma_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].b_tma_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].a_tma_desc_layout, TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].b_tma_desc_layout])` ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `to_kernel_args` `static to_kernel_args(a: LayoutTensor[a_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext) -> TMALoadOp[a_type, b_type, block_tile_shape, cluster_shape, a_swizzle, b_swizzle].args_type` **Returns:** `TMALoadOpArgs`: The kernel-ready argument struct (this op's `args_type`). ### `__call__` `__call__(self, a_smem_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_smem_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], m: UInt32, n: UInt32, k: UInt32, ref [3] mbar: SharedMemBarrier)`
--- ## TMALoadOpArgs
`struct TMALoadOpArgs[a_type: DType, b_type: DType, a_layout: Layout, b_layout: Layout, a_desc_layout: Layout = a_layout, b_desc_layout: Layout = b_layout]` ## Fields * ​a\_tma\_op (`TMATensorTile[a_type, a_layout, a_desc_layout]`): * ​b\_tma\_op (`TMATensorTile[b_type, b_layout, b_desc_layout]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OpArgs`](/mojo/kernels/linalg/matmul/gpu/sm100/composable/OpArgs), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, a: TMATensorTile[a_type, a_layout, a_desc_layout], b: TMATensorTile[b_type, b_layout, b_desc_layout])` ### `__copyinit__` `__copyinit__(out self, other: Self)`
--- ## composable
## Structs * [​`Pipeline`](./Pipeline): * [​`PipelineArgs`](./PipelineArgs): * [​`R2GOutputOp`](./R2GOutputOp): * [​`STOutputOpArgs`](./STOutputOpArgs): * [​`TMALoadOp`](./TMALoadOp): * [​`TMALoadOpArgs`](./TMALoadOpArgs): ## Traits * [​`LoadOp`](./LoadOp): * [​`MmaOp`](./MmaOp): * [​`OpArgs`](./OpArgs): * [​`OutputOp`](./OutputOp): * [​`PipelineOp`](./PipelineOp): ## Functions * [​`matmul_kernel`](./matmul_kernel): * [​`matmul_sm100`](./matmul_sm100):
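The pieces above compose into a full matmul: a `LoadOp` fills shared-memory tiles, an `OutputOp` writes results out of tensor memory, and `Pipeline` binds both for `matmul_kernel`. A rough host-side wiring sketch follows; the parameter values, the `out_t` placeholder, and the launch step are illustrative assumptions, not taken from this module's documentation:

```mojo
# Illustrative wiring only; concrete values are placeholders.
comptime load_t = TMALoadOp[
    DType.bfloat16, DType.bfloat16,
    Index(128, 128, 64),  # block_tile_shape (illustrative)
    Index(1, 1, 1),       # cluster_shape (illustrative)
]
# out_t would be an OutputOp such as R2GOutputOp, parameterized to match.

# Host side: convert device tensors into kernel-ready argument structs,
# then bundle them for the pipeline.
var load_args = load_t.to_kernel_args(a, b, ctx)
var out_args = out_t.to_kernel_args(c, ctx)
var args = PipelineArgs[load_t, out_t](load_args, out_args, num_iters)
# matmul_kernel[Pipeline[...]] is then launched with `args`.
```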
--- ## matmul_kernel (Composable)
`matmul_kernel[pipeline_t: PipelineOp](args: pipeline_t.args_type)`
--- ## matmul_sm100
`matmul_sm100[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, mma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B](c_device: NDBuffer[c_type, 2, origin, c_shape], a_device: NDBuffer[a_type, 2, origin, a_shape], b_device: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)`
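A minimal call sketch, assuming `c_dev`, `a_dev`, `b_dev` are device-resident `NDBuffer`s already created through `ctx`; the dtypes, shapes, and tile sizes below are illustrative, not recommended values:

```mojo
# Illustrative invocation; c_shape/a_shape/b_shape are assumed DimLists.
matmul_sm100[
    DType.bfloat16, c_shape,   # c_type, c_shape
    DType.bfloat16, a_shape,   # a_type, a_shape
    DType.bfloat16, b_shape,   # b_type, b_shape
    mma_shape = Index(64, 128, 16),         # illustrative
    block_tile_shape = Index(64, 128, 64),  # illustrative
](c_dev, a_dev, b_dev, ctx)
```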
--- ## BlockScaledMatmulConfig
`@register_passable(trivial)` `struct BlockScaledMatmulConfig[a_type: DType, b_type: DType, c_type: DType, sfa_dtype: DType, sfb_dtype: DType, transpose_b: Bool = True]` Static configuration of GPU matmul. ## Fields * ​cta\_group (`Int`): * ​mma\_shape (`IndexList[3]`): * ​cluster\_shape (`IndexList[3]`): * ​AB\_swapped (`Bool`): * ​block\_swizzle\_size (`Int`): * ​raster\_order (`RasterOrder`): * ​block\_tile\_shape (`IndexList[3]`): * ​num\_split\_k (`Int`): * ​num\_pipeline\_stages (`UInt`): * ​num\_clc\_pipeline\_stages (`UInt`): * ​num\_accum\_pipeline\_stages (`UInt`): * ​num\_output\_stages (`UInt`): * ​output\_tile\_shape (`IndexList[2]`): * ​a\_swizzle (`TensorMapSwizzle`): * ​b\_swizzle (`TensorMapSwizzle`): * ​c\_swizzle (`TensorMapSwizzle`): * ​k\_group\_size (`UInt`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `sf_block_atom_size` `comptime sf_block_atom_size = ((SF_ATOM_M.__getitem__[Int, Int, 0]() * SF_ATOM_M.__getitem__[Int, Int, 1]()) * 4)` ## Methods ### `__init__` `__init__(*, cta_group: Int = 2, mma_shape: IndexList[3] = get_mma_shape[a_type, BlockScaledMatmulConfig[a_type, b_type, c_type, sfa_dtype, sfb_dtype, transpose_b].accum_type](), 
cluster_shape: IndexList[3] = Index(2, 1, 1), AB_swapped: Bool = False, num_split_k: Int = 1, block_swizzle_size: Int = 0, raster_order: RasterOrder = RasterOrder.AlongM, k_group_size: UInt = 1, num_pipeline_stages: Optional[UInt] = None, num_accum_pipeline_stages: UInt = 2, num_clc_pipeline_stages: UInt = 2) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `swap_AB_type` `swap_AB_type(self) -> BlockScaledMatmulConfig[b_type, a_type, c_type, sfa_dtype, sfb_dtype, transpose_b]` **Returns:** [`BlockScaledMatmulConfig`](/mojo/kernels/linalg/matmul/gpu/sm100/config/BlockScaledMatmulConfig) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)` ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance.
--- ## MatmulConfig
`@register_passable(trivial)` `struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True]` Static configuration of GPU matmul. ## Fields * ​cta\_group (`Int`): * ​mma\_shape (`IndexList[3]`): * ​cluster\_shape (`IndexList[3]`): * ​AB\_swapped (`Bool`): * ​block\_swizzle\_size (`Int`): * ​raster\_order (`RasterOrder`): * ​block\_tile\_shape (`IndexList[3]`): * ​num\_split\_k (`Int`): * ​num\_pipeline\_stages (`UInt`): * ​num\_clc\_pipeline\_stages (`UInt`): * ​num\_accum\_pipeline\_stages (`UInt`): * ​num\_output\_stages (`UInt`): * ​output\_tile\_shape (`IndexList[2]`): * ​a\_swizzle (`TensorMapSwizzle`): * ​b\_swizzle (`TensorMapSwizzle`): * ​c\_swizzle (`TensorMapSwizzle`): * ​k\_group\_size (`UInt`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ## Methods ### `__init__` `__init__(*, cta_group: Int = 2, mma_shape: IndexList[3] = get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[3] = Index(2, 1, 1), AB_swapped: Bool = False, num_split_k: Int = 1, block_swizzle_size: Int = 0, raster_order: RasterOrder = RasterOrder.AlongM, k_group_size: UInt = 1, num_pipeline_stages: Optional[UInt] = None, 
num_accum_pipeline_stages: UInt = 2, num_clc_pipeline_stages: UInt = 2) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `swap_AB_type` `swap_AB_type(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)` ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance.
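As a sketch, constructing a config with a few non-default options looks like the following; the field values are illustrative, and any omitted keyword falls back to the defaults shown in `__init__`:

```mojo
# Illustrative construction; dtypes and option values are placeholders.
comptime config = MatmulConfig[
    DType.bfloat16, DType.bfloat16, DType.bfloat16, transpose_b=True
](
    cta_group=2,
    cluster_shape=Index(2, 1, 1),
    num_split_k=1,
    raster_order=RasterOrder.AlongM,
)
```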
--- ## build_configs
`build_configs[a_type: DType, b_type: DType, c_type: DType, N: Int, K: Int, transpose_b: Bool = True]() -> Set[MatmulConfig[a_type, b_type, c_type, transpose_b]]` **Returns:** `Set`
--- ## choose_config
`choose_config[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = True](M: Int, N: Int, K: Int) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
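Together with `build_configs`, this supports a simple selection flow: ask the heuristic for one config given runtime problem sizes, or enumerate a candidate set for fixed `N` and `K` (e.g. for tuning). A hedged sketch with illustrative dtypes and sizes:

```mojo
# Runtime heuristic pick for a given problem size (values illustrative).
var cfg = choose_config[DType.bfloat16, DType.bfloat16, DType.bfloat16](
    4096, 4096, 4096  # M, N, K
)

# Candidate set for fixed N and K, e.g. to sweep during autotuning.
var candidates = build_configs[
    DType.bfloat16, DType.bfloat16, DType.bfloat16, 4096, 4096
]()
```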
--- ## config
## Structs * [​`BlockScaledMatmulConfig`](./BlockScaledMatmulConfig): Static configuration of GPU matmul. * [​`MatmulConfig`](./MatmulConfig): Static configuration of GPU matmul. ## Functions * [​`build_configs`](./build_configs): * [​`choose_config`](./choose_config):
--- ## heuristic_and_outliers_dispatch
`heuristic_and_outliers_dispatch[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## dispatch
## `comptime` values ### `DISPATCH_HIT` `comptime DISPATCH_HIT = 1` ### `DISPATCH_MISS` `comptime DISPATCH_MISS = 0` ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Functions * [​`heuristic_and_outliers_dispatch`](./heuristic_and_outliers_dispatch): * [​`matmul_dispatch_sm100`](./matmul_dispatch_sm100): * [​`matmul_dispatch_sm100_bf16`](./matmul_dispatch_sm100_bf16): * [​`matmul_dispatch_sm100_fp8`](./matmul_dispatch_sm100_fp8):
--- ## matmul_dispatch_sm100
`matmul_dispatch_sm100[c_type: DType, a_type: DType, b_type: DType, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) `matmul_dispatch_sm100[c_type: DType, a_type: DType, b_type: DType, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_lambda_wrapper: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext)`
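The `Int`-returning overload reports whether dispatch found a specialized kernel; comparing the result against the module's `DISPATCH_HIT`/`DISPATCH_MISS` values lets a caller fall back. A sketch, assuming `c`, `a`, `b` are appropriately shaped `NDBuffer`s (the fallback branch is a placeholder):

```mojo
# Illustrative dispatch-and-fallback pattern.
var status = matmul_dispatch_sm100[
    DType.bfloat16, DType.bfloat16, DType.bfloat16, transpose_b=True
](c, a, b, ctx)
if status == DISPATCH_MISS:
    # No SM100 specialization matched; a caller would invoke a
    # generic matmul path here (placeholder).
    pass
```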
--- ## matmul_dispatch_sm100_bf16
`matmul_dispatch_sm100_bf16[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## matmul_dispatch_sm100_fp8
`matmul_dispatch_sm100_fp8[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## sm100 (Sm100)
Provides the Nvidia Blackwell backend implementations for matmuls. ## Modules * [​`block_scaled_matmul`](./block_scaled_matmul/): * [​`blockwise_fp8`](./blockwise_fp8/): * [​`composable`](./composable/): * [​`config`](./config/): * [​`dispatch`](./dispatch/): * [​`matmul`](./matmul/): * [​`pipeline`](./pipeline/): * [​`tile_scheduler`](./tile_scheduler/): * [​`tile_scheduler_splitk`](./tile_scheduler_splitk/): * [​`tuning_configs`](./tuning_configs/): * [​`warp_specialized_blockwise_fp8`](./warp_specialized_blockwise_fp8/):
--- ## B200MatmulSmem
`struct B200MatmulSmem[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]]` ## Fields * ​a\_smem (`InlineArray[B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].AType, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].a_smem_size]`): * ​b\_smem (`InlineArray[B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BType, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].b_smem_size]`): * ​c\_smem (`InlineArray[B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].CType, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].c_smem_size]`): * ​tma\_mma\_mbars (`InlineArray[SharedMemBarrier, (Int(B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_group_pipeline_stages) * 2)]`): * ​accum\_mbars (`InlineArray[SharedMemBarrier, (Int(config) * 2)]`): * ​clc\_mbars\_full (`InlineArray[SharedMemBarrier, Int(config)]`): * ​clc\_mbars\_empty (`InlineArray[SharedMemBarrier, Int(config)]`): * ​clc\_throttle\_mbars (`InlineArray[SharedMemBarrier, (Int(config) * 2)]`): * ​clc\_response (`InlineArray[UInt128, Int(config)]`): * ​tmem\_dealloc\_mbar (`InlineArray[SharedMemBarrier, 1]`): * ​tmem\_addr (`InlineArray[UInt32, 1]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_size` `comptime a_smem_size = ((B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BM * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK) * Int(config))` ### `AType` `comptime AType = Scalar[a_type]` ### `b_smem_size` `comptime b_smem_size = ((B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BN * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK) * Int(config))` ### 
`BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BType` `comptime BType = Scalar[b_type]` ### `c_smem_size` `comptime c_smem_size = ((B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputM * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputN) * Int(config))` ### `CType` `comptime CType = Scalar[c_type]` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (config // config)` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)`
--- ## WarpRole (Matmul)
`@register_passable(trivial)` `struct WarpRole` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Epilogue` `comptime Epilogue = WarpRole(3)` ### `MainLoad` `comptime MainLoad = WarpRole(5)` ### `Mma` `comptime Mma = WarpRole(6)` ### `Scheduler` `comptime Scheduler = WarpRole(4)` ## Methods ### `__eq__` `__eq__(self, other: UInt) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ge__` `__ge__(self, other: UInt) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_main_load` `static is_main_load() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_mma` `static is_mma() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_epilogue` `static is_epilogue() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_scheduler` `static is_scheduler() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
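The role constants and static predicates above are typically used to branch a warp-specialized kernel into its load, MMA, scheduler, and epilogue paths. The fragment below is a hypothetical dispatch sketch using only the static predicates listed; the work done in each branch is illustrative, not part of this API:

```mojo
# Hypothetical warp-specialized dispatch; each warp takes exactly one branch.
if WarpRole.is_main_load():
    # Issue TMA loads of A/B tiles into shared memory.
    pass
elif WarpRole.is_mma():
    # Drive the tensor-core MMA instructions.
    pass
elif WarpRole.is_scheduler():
    # Fetch new work tiles for the CTA.
    pass
elif WarpRole.is_epilogue():
    # Write accumulator results back to global memory.
    pass
```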
--- ## accum_arrive
`accum_arrive[cta_group: Int](mma_output_pipeline: ProducerConsumerPipeline[num_stages], mma_output_stage: UInt32)`
--- ## blackwell_matmul_tma_umma_warp_specialized
`blackwell_matmul_tma_umma_warp_specialized[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: OptionalReg[UInt32] = None](c_device: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_device: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_device: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## blackwell_tma_umma_warp_specialized_kernel (Matmul)
`blackwell_tma_umma_warp_specialized_kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])`
--- ## blackwell_tma_umma_warp_specialized_split_k_kernel
`blackwell_tma_umma_warp_specialized_split_k_kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, reduction_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, max_profiled_tiles_per_SM: UInt32 = 0](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], reduction_tensor: LayoutTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin], lock_ptr: LegacyUnsafePointer[UInt8], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])`
--- ## consumer_main_loop (Matmul)
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, a_smem_layout: Layout, b_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), k_group_size: UInt = 1](tmem_addr: UInt32, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[pipeline_stages], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)`
--- ## copy_accum_to_gmem
`copy_accum_to_gmem[c_type: DType, c_layout: Layout, c_smem_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: Int, /, *, repeat: Int, accum_type: DType, cta_group: Int, epilogue_dtype: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], num_output_warps: UInt, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False](c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], mma_output_pipeline: ProducerConsumerPipeline[num_accum_pipeline_stages], mma_output_stage: UInt32, tmem_offset: UInt32, c_coord: Tuple[UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])`
--- ## f32_frag_to_smem
`f32_frag_to_smem[swizzle_mode: TensorMapSwizzle, stageN: UInt](vec: SIMD[dtype, size], dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## matmul (3)
## `comptime` values ### `RLayout32Bits` `comptime RLayout32Bits[layout: Layout] = RuntimeLayout[layout, element_type=DType.uint32, linear_idx_type=DType.uint32]` #### Parameters * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Structs * [​`B200MatmulSmem`](./B200MatmulSmem): * [​`WarpRole`](./WarpRole): ## Functions * [​`accum_arrive`](./accum_arrive): * [​`blackwell_matmul_tma_umma_warp_specialized`](./blackwell_matmul_tma_umma_warp_specialized): * [​`blackwell_tma_umma_warp_specialized_kernel`](./blackwell_tma_umma_warp_specialized_kernel): * [​`blackwell_tma_umma_warp_specialized_split_k_kernel`](./blackwell_tma_umma_warp_specialized_split_k_kernel): * [​`consumer_main_loop`](./consumer_main_loop): * [​`copy_accum_to_gmem`](./copy_accum_to_gmem): * [​`f32_frag_to_smem`](./f32_frag_to_smem): * [​`load_AB`](./load_AB): * [​`matmul_sm100_fallback`](./matmul_sm100_fallback): * [​`matmul_sm100_fallback_kernel`](./matmul_sm100_fallback_kernel): * [​`multi_stage_store_C`](./multi_stage_store_C): * [​`multi_stage_store_C_split_k`](./multi_stage_store_C_split_k): * [​`register_epilogue`](./register_epilogue): * [​`shared_memory_epilogue`](./shared_memory_epilogue): * [​`shared_memory_epilogue_transpose`](./shared_memory_epilogue_transpose): * [​`stsm_helper`](./stsm_helper):
--- ## load_AB (Matmul)
`load_AB[a_type: DType, b_type: DType, a_layout: Layout, b_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, num_pipeline_stages: UInt, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, k_group_size: UInt = 1](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[Int(num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt32, elect_one_cta: Bool)`
--- ## matmul_sm100_fallback
`matmul_sm100_fallback[a_layout: Layout, b_layout: Layout, c_layout: Layout, c_type: DType, a_type: DType, b_type: DType, *, transpose_b: Bool, umma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## matmul_sm100_fallback_kernel
`matmul_sm100_fallback_kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: UInt = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], num_iters: UInt)`
--- ## multi_stage_store_C (Matmul)
`multi_stage_store_C[c_type: DType, c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: UInt, /, *, input_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], stage_stride_cols: UInt, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 1, num_output_warps: UInt = 4, max_tmem_cols: UInt = 512, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False](c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], mma_output_pipeline: ProducerConsumerPipeline[Int(num_accum_pipeline_stages)], tmem_addr: UInt32, work_tile_coord: Tuple[UInt32, UInt32], elect_one_warp: Bool, M: UInt32, N: UInt32)`
--- ## multi_stage_store_C_split_k
`multi_stage_store_C_split_k[c_type: DType, c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, reduction_layout: Layout, num_accum_pipeline_stages: UInt, /, *, input_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], stage_stride_cols: UInt, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 1, num_output_warps: UInt = 4, max_tmem_cols: UInt = 512, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False](scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], reduction_tensor: LayoutTensor[accum_type, reduction_layout, MutAnyOrigin], c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], mma_output_pipeline: ProducerConsumerPipeline[Int(num_accum_pipeline_stages)], tmem_addr: UInt32, work_info: WorkInfo, elect_one_warp: Bool, M: UInt32, N: UInt32)`
--- ## register_epilogue
`register_epilogue[MMA_M: UInt, data_paths: UInt, num_stages: UInt, bits: UInt, stage: UInt, stageN: UInt, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: UInt, epilogue_dtype: DType, frag_size: UInt, repeats: UInt, transpose_c: Bool, cta_group: Int, is_lower_frag_required: Bool](mut upper_frag_casted: SIMD[epilogue_dtype, Int(frag_size)], mut lower_frag_casted: SIMD[epilogue_dtype, Int(frag_size)], c_row: UInt32, c_col: UInt32, N: UInt32)`
--- ## shared_memory_epilogue
`shared_memory_epilogue[MMA_M: UInt, data_paths: UInt, num_stages: UInt, stage: UInt, stageN: UInt, c_type: DType, shared_n: UInt, simd_size: UInt, c_smem_upper_layout: Layout, c_smem_lower_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: UInt](M: UInt32, N: UInt32, c_col: UInt, c_row: UInt, c_smem_warp_tile_upper: LayoutTensor[c_type, c_smem_upper_layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_smem_warp_tile_lower: LayoutTensor[c_type, c_smem_lower_layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## shared_memory_epilogue_transpose
`shared_memory_epilogue_transpose[stage: UInt, stageN: UInt, c_type: DType, c_smem_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: UInt, warp_dim: UInt, MMA_M: Int, BN: Int, cta_group: Int](M: UInt32, N: UInt32, c_col: UInt, c_row: UInt, c_smem: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_i: UInt, warp_j: UInt)`
--- ## stsm_helper (Matmul)
`stsm_helper[swizzle: Swizzle, stageN: UInt, transpose_c: Bool = False, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B](vec: SIMD[dtype, size], dst: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_offset: UInt32 = 0)`
--- ## ProducerConsumerPipeline
`@register_passable(trivial)` `struct ProducerConsumerPipeline[num_stages: Int]` A producer-consumer pipeline using shared memory barriers to enforce synchronization (between producer and consumer warps). This struct is commonly used with warp specialization to pipeline operations between two warps/warpgroups with data dependencies. ## Parameters * ​num\_stages ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of pipeline stages. ## Fields * ​full (`MbarPtr`): * ​empty (`MbarPtr`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` Initialize the producer-consumer pipeline with default phases. **Args:** * ​ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barriers. ### `wait_producer` `wait_producer(self)` Consumer waits for producer. ### `wait_consumer` `wait_consumer(self)` Producer waits for consumer. ### `producer_mbar` `producer_mbar(self, stage: UInt32) -> MbarPtr` Get the producer barrier for a specific stage. **Args:** * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The pipeline stage. **Returns:** `MbarPtr`: The shared memory barrier that the producer signals. ### `consumer_mbar` `consumer_mbar(self, stage: UInt32) -> MbarPtr` Get the consumer barrier for a specific stage. **Args:** * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The pipeline stage. 
**Returns:** `MbarPtr`: The shared memory barrier that the consumer signals. ### `producer_stage` `producer_stage(self) -> UInt32` Get the current producer stage index. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The current stage index for the producer (0 to num\_stages-1). ### `consumer_stage` `consumer_stage(self) -> UInt32` Get the current consumer stage index. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The current stage index for the consumer (0 to num\_stages-1). ### `consumer_step` `consumer_step(mut self)` Advance the consumer to the next pipeline stage. Increments the consumer stage and wraps to 0 when reaching num\_stages, toggling the phase bit on wrap-around. The phase bit only switches at the end of the pipeline because all barriers are assumed to be at the same consumer/producer phase before they are checked; once a barrier is checked, execution moves on to the next barrier. ### `producer_step` `producer_step(mut self)` Advance the producer to the next pipeline stage. Increments the producer stage and wraps to 0 when reaching num\_stages, toggling the phase bit on wrap-around. ### `smem_bytes` `static smem_bytes() -> UInt32` Calculate the shared memory bytes required for pipeline barriers. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The total number of bytes needed for all pipeline barriers (2 \* num\_stages barriers). ### `init_mbars` `init_mbars(self, producer_arrive_count: Int32, consumer_arrive_count: Int32)` Initialize the smem barriers for the producer and consumer. This function must be called by a single thread, before the pipeline object is used. **Args:** * ​producer\_arrive\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The number of threads that will arrive at the barrier marking data as produced. * ​consumer\_arrive\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The number of threads that will arrive at the barrier marking data as consumed.
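Taken together, the methods above follow a standard acquire/advance pattern. The fragment below is a minimal, hypothetical sketch of that flow: `mbar_ptr`, `elect_one_thread`, the stage count, and the arrive counts of 1 are all assumptions, and a real kernel (see `load_AB` and `consumer_main_loop`) runs the producer and consumer halves in different warps:

```mojo
comptime num_stages = 4  # illustrative stage count
var pipeline = ProducerConsumerPipeline[num_stages](mbar_ptr)

# Exactly one thread initializes the barriers before any use.
if elect_one_thread:
    pipeline.init_mbars(1, 1)

# Producer side (e.g. a TMA load warp):
pipeline.wait_consumer()                 # wait until the current stage is free
var write_stage = pipeline.producer_stage()
# ... fill stage `write_stage`, signaling via pipeline.producer_mbar(write_stage) ...
pipeline.producer_step()                 # advance; phase flips on wrap-around

# Consumer side (e.g. an MMA warp):
pipeline.wait_producer()                 # wait until data is ready
var read_stage = pipeline.consumer_stage()
# ... consume stage `read_stage`, signaling via pipeline.consumer_mbar(read_stage) ...
pipeline.consumer_step()
```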
--- ## pipeline (Pipeline)
## `comptime` values ### `MbarPtr` `comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` ## Structs * [​`ProducerConsumerPipeline`](./ProducerConsumerPipeline): A producer-consumer pipeline using shared memory barriers to enforce synchronization (between producer and consumer warps).
--- ## TileScheduler (Tile_scheduler)
`@register_passable(trivial)` `struct TileScheduler[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8]` ## Fields * ​cluster\_dim (`StaticTuple[Int32, 3]`): * ​log\_cluster\_dim\_m (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_n (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_k (`FastDiv[DType.uint32]`): * ​clc\_response (`LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]`): * ​full\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​empty\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `cluster_size` `comptime cluster_size = ((cluster_shape.__getitem__[3, DType.uint32, Int](0) * cluster_shape.__getitem__[3, DType.uint32, Int](1)) * cluster_shape.__getitem__[3, DType.uint32, Int](2))` ### `log_cluster_k` `comptime log_cluster_k = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](2))` ### `log_cluster_m` `comptime log_cluster_m = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](0))` ### `log_cluster_n` `comptime log_cluster_n = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](1))` ## Methods ### `__init__` `__init__(cluster_dim: StaticTuple[Int32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], 
full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `work_info_from_clc_response` `static work_info_from_clc_response(result: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `work_info_from_cluster` `static work_info_from_cluster(work_info: WorkInfo, cluster_dim: StaticTuple[Int32, 3], log_cluster_dim_m: FastDiv[DType.uint32], log_cluster_dim_n: FastDiv[DType.uint32]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `initial_work_info` `initial_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `fetch_next_work` `fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_to_next_work` `advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]` **Returns:** [`PipelineState`](/mojo/kernels/layout/tma_async/PipelineState)
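The scheduler methods above are typically driven in a fetch/advance loop by the compute warps, while a dedicated scheduler warp requests new work via CLC. The sketch below is hypothetical: the constructor arguments mirror `__init__` above but their setup is not shown, and `consumer_state` handling is an assumption based on the `PipelineState[num_stages]` parameter in the signatures:

```mojo
# Hypothetical driver loop on the compute side.
var scheduler = TileScheduler[num_stages](
    cluster_dim, clc_response_ptr, full_mbar_ptr, empty_mbar_ptr
)
var work_info = scheduler.initial_work_info()
while work_info.is_valid():
    # ... compute the output tile at (work_info.m, work_info.n) ...
    work_info = scheduler.fetch_next_work(work_info, consumer_state)

# A scheduler warp would instead call
# `clc_state = scheduler.advance_to_next_work(clc_state)` to enqueue new work.
```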
--- ## WorkInfo (Tile_scheduler)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## tile_scheduler
## Structs * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo):
--- ## TileScheduler (Tile_scheduler_splitk)
`@register_passable(trivial)` `struct TileScheduler[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8, num_split_k: Int = 1]` ## Fields * ​locks\_ptr (`LegacyUnsafePointer[Int32]`): * ​scheduler (`TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler`): * ​total\_k\_tiles (`UInt32`): * ​k\_tiles\_per\_split (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BK` `comptime BK = reduction_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = reduction_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = reduction_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `ROW_SIZE` `comptime ROW_SIZE = reduction_tile_shape.__getitem__[3, DType.int64, Int](1) if (reduction_tile_shape.__getitem__[3, DType.int64, Int](0) == 128) else (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].MMA_N // 2)` ### `UnderlyingScheduler` `comptime UnderlyingScheduler = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ## Methods ### `__init__` `__init__(cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], 
full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], locks_ptr: LegacyUnsafePointer[UInt8]) -> Self` ### `convert_to_splitk_work_info` `convert_to_splitk_work_info(self, work_info: WorkInfo) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `initial_work_info` `initial_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_to_next_work` `advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]` **Returns:** [`PipelineState`](/mojo/kernels/layout/tma_async/PipelineState) ### `fetch_next_work` `fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `is_last_split` `is_last_split(self, work_tile_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `output_tile_index` `output_tile_index(self, work_info: WorkInfo) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `store_to_workspace` `store_to_workspace[accum_type: DType, workspace_layout: Layout, /, *, do_reduction: Bool = False, write_back: Bool = False](self, tmem_addr: UInt32, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], epilogue_thread_idx: UInt, reduction_tile_idx: UInt32)` ### `reduction` `reduction[accum_type: DType, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], tmem_addr: UInt32, epilogue_thread_idx: UInt, work_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `wait_eq` `static wait_eq(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)` ### `wait_lt` `static 
wait_lt(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)` ### `arrive_set` `static arrive_set(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)`
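As a rough illustration of how the static lock helpers (`wait_eq`, `arrive_set`) are intended to compose for split-k serialization — a device-side sketch, not a runnable kernel; `lock_ptr`, `tile_idx`, `split_idx`, `tid`, and the barrier id are hypothetical values supplied by the surrounding kernel:

```mojo
# Hypothetical serialization between consecutive splits of one output
# tile. All identifiers below come from the surrounding kernel code.
comptime barrier_id = Int32(7)  # illustrative named-barrier id

if split_idx > 0:
    # Block until the previous split has published its partial result.
    TileScheduler.wait_eq(lock_ptr, barrier_id, tid, tile_idx, split_idx)

# ... reduce this split's accumulators into the reduction workspace ...

# Mark this split as done so the next split (or final writer) proceeds.
TileScheduler.arrive_set(lock_ptr, barrier_id, tid, tile_idx, split_idx + 1)
```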
--- ## WorkInfo (Tile_scheduler_splitk)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​num\_k\_tiles (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `INVALID_WORK_INFO` `comptime INVALID_WORK_INFO = WorkInfo(0, 0, 0, 0, False)` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_final_split` `is_final_split(self, k_tiles_per_output_tile: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
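A minimal sketch of the intended consumption pattern, assuming a `scheduler` and `consumer_state` already set up in the surrounding kernel (both hypothetical here); invalid work info terminates the loop:

```mojo
# Hypothetical work-consumption loop: iterate scheduled tiles until the
# scheduler hands back an invalid WorkInfo.
var work_info = scheduler.initial_work_info()
while work_info.is_valid():
    # Process the (m, n) tile over k tiles
    # [k_start, k_start + num_k_tiles).
    work_info = scheduler.fetch_next_work(work_info, consumer_state)
```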
--- ## get_num_tiles
`get_num_tiles(problem_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## get_required_locks_buffer_size_bytes
`get_required_locks_buffer_size_bytes[accum_type: DType](problem_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
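A hedged host-side sketch combining the two helpers above to size the split-k locks buffer before launch; the shapes are illustrative placeholders (not tuned values), and the import path is an assumption:

```mojo
from utils.index import Index  # assumed stdlib location of Index

# Placeholder problem and tile shapes, for illustration only.
var problem_shape = Index(4096, 4096, 4096)  # (M, N, K)
var block_tile_shape = Index(128, 128, 64)   # (BM, BN, BK)
var cluster_shape = Index(2, 1)

var num_tiles = get_num_tiles(problem_shape, block_tile_shape, cluster_shape)
var locks_bytes = get_required_locks_buffer_size_bytes[DType.float32](
    problem_shape, block_tile_shape, cluster_shape
)
# `locks_bytes` would then size a zero-initialized device allocation
# passed to the kernel as the locks buffer.
```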
--- ## tile_scheduler_splitk
## Structs * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo): ## Functions * [​`get_num_tiles`](./get_num_tiles): * [​`get_required_locks_buffer_size_bytes`](./get_required_locks_buffer_size_bytes):
--- ## TuningConfigSM100
`@register_passable(trivial)` `struct TuningConfigSM100` ## Fields * ​M (`Int`): * ​M\_end (`Int`): * ​N (`Int`): * ​K (`Int`): * ​mma\_shape (`IndexList[3]`): * ​block\_tile\_shape (`IndexList[3]`): * ​cluster\_shape (`IndexList[3]`): * ​block\_swizzle\_size (`UInt`): * ​rasterize\_order (`RasterOrder`): * ​cta\_group (`Int`): * ​swapAB (`Bool`): * ​k\_group\_size (`UInt`): * ​num\_accum\_pipeline\_stages (`UInt`): * ​num\_clc\_pipeline\_stages (`UInt`): * ​num\_split\_k (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`TuningConfig`](/mojo/kernels/internal_utils/dispatch_utils/TuningConfig), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(M: Int, N: Int, K: Int, mma_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[3], block_swizzle_size: UInt, rasterize_order: RasterOrder, cta_group: Int = 2, swapAB: Bool = False, k_group_size: UInt = 1, num_accum_pipeline_stages: UInt = 2, num_clc_pipeline_stages: UInt = 2, num_split_k: Int = 1) -> Self` `__init__(M: Int, M_end: Int, N: Int, K: Int, mma_shape: IndexList[3], cta_group: Int, cluster_shape: IndexList[3], block_swizzle_size: UInt, rasterize_order: RasterOrder, swapAB: Bool = False, k_group_size: UInt = 1, num_accum_pipeline_stages: UInt = 2, num_clc_pipeline_stages: UInt = 2, num_split_k: Int = 1) -> Self` ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
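A hedged construction example using the first `__init__` overload documented above; the numeric values are illustrative placeholders, not validated tunings, and the `Index` import path is an assumption:

```mojo
from utils.index import Index  # assumed stdlib location of Index

# Illustrative tuning entry; defaults cover cta_group, swapAB, etc.
var cfg = TuningConfigSM100(
    M=4096, N=4096, K=4096,
    mma_shape=Index(128, 256, 16),
    block_tile_shape=Index(128, 128, 64),
    cluster_shape=Index(2, 1, 1),
    block_swizzle_size=UInt(8),
    rasterize_order=RasterOrder.AlongM,
)
print(String(cfg))  # TuningConfigSM100 is Stringable
```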
--- ## tuning_configs
## Structs * [​`TuningConfigSM100`](./TuningConfigSM100):
--- ## blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel
`blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_scales_tile_layout: Layout, a_scales_type: DType, b_scales_type: DType, b_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, a_scales_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], num_pipeline_stages: UInt, cluster_shape: StaticTuple[Int32, 3]](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_tile_layout, a_scales_desc_layout], cluster_dim: StaticTuple[Int32, 3], num_iters: UInt, b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], problem_shape: StaticTuple[Int32, 3])`
--- ## warp_specialized_blockwise_fp8
## Functions * [​`blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel`](./blackwell_tma_umma_warp_specialized_blockwise_fp8_kernel): * [​`load_AB`](./load_AB): * [​`multi_stage_reg_epilogue`](./multi_stage_reg_epilogue): * [​`promote_accumulators`](./promote_accumulators): * [​`sm100_warp_specialized_blockwise_fp8`](./sm100_warp_specialized_blockwise_fp8):
--- ## load_AB (Warp_specialized_blockwise_fp8)
`load_AB[a_type: DType, b_type: DType, a_scales_type: DType, a_layout: Layout, b_layout: Layout, a_scales_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, a_scales_desc_layout: Layout, a_smem_layout: Layout, b_smem_layout: Layout, a_scales_smem_layout: Layout, num_pipeline_stages: UInt, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], a_scales_tma_op: TMATensorTile[a_scales_type, a_scales_layout, a_scales_desc_layout], a_smem: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], a_scales_smem: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[Int(num_pipeline_stages)], peer_cta_coord: Tuple[UInt, UInt, UInt], work_tile_coord: Tuple[UInt, UInt], a_multicast_mask: UInt16, b_multicast_mask: UInt16, iter_idx: UInt, elect_one_cta: Bool)`
--- ## multi_stage_reg_epilogue (Warp_specialized_blockwise_fp8)
`multi_stage_reg_epilogue[c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, accum_type: DType, accum_layout: Layout, /, *, c_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], is_lower_frag_required: Bool, cta_group: Int, num_output_warps: UInt, c_swizzle: TensorMapSwizzle](c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_iter: LayoutTensorIter[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], c_coord: Tuple[UInt, UInt], elect_one_warp: Bool)`
--- ## promote_accumulators (Warp_specialized_blockwise_fp8)
`promote_accumulators[pipeline_stages: UInt, num_accum_pipeline_stages: UInt, accum_type: DType, accum_layout: Layout, a_scales_type: DType, b_scales_type: DType, b_scales_layout: Layout, a_scales_smem_layout: Layout, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int, CLUSTER_SIZE: Int32, is_lower_frag_required: Bool, num_output_warps: UInt](b_scales: LayoutTensor[b_scales_type, b_scales_layout, MutAnyOrigin], a_scales_smem_iter: LayoutTensorIter[a_scales_type, a_scales_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], c_upper_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_lower_main_tile: LayoutTensor[accum_type, accum_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mma_output_pipeline: ProducerConsumerPipeline[Int(num_accum_pipeline_stages)], tmem_addr: UInt32, load_mma_pipeline: ProducerConsumerPipeline[Int(pipeline_stages)], work_tile_coord: Tuple[UInt, UInt], elect_one_warp: Bool, stage_stride_cols: UInt, k_iter: UInt, problem_shape: StaticTuple[Int32, 3])`
--- ## sm100_warp_specialized_blockwise_fp8
`sm100_warp_specialized_blockwise_fp8[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, a_scales_layout: Layout, b_scales_layout: Layout, a_scales_type: DType, b_scales_type: DType, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_scales: LayoutTensor[a_scales_type, a_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_scales: LayoutTensor[b_scales_type, b_scales_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## sm100_structured
SM100 Structured Matmul - Refactored with encapsulated pipeline management. This module provides the same SM100 matmul functionality as the original sm100 module, but with improved code organization.

Key abstractions:

* WorkIterator/SchedulerWorkIterator: Encapsulate work iteration and pipeline state
* RingBuffer/OutputRingBuffer: Encapsulate producer-consumer synchronization
* TileLoaderTMA: Encapsulate TMA tile loading logic
* Context managers for cleaner acquire/release patterns

## Switching Implementations

### Option 1: Environment Variable (Recommended)

Set `MODULAR_USE_STRUCTURED_SM100=1` to use this implementation:

```bash
# Use original sm100 (default):
./bazelw run //max/kernels/test/gpu/linalg:test_matmul_sm100_smoke.mojo.test

# Use sm100_structured:
MODULAR_USE_STRUCTURED_SM100=1 ./bazelw run //max/kernels/test/gpu/linalg:test_matmul_sm100_smoke.mojo.test
```

### Option 2: Direct Import

```mojo
# Original:
from linalg.matmul.gpu.sm100.matmul import (
    blackwell_matmul_tma_umma_warp_specialized
)

# Structured (this module):
from linalg.matmul.gpu.sm100_structured import (
    blackwell_matmul_tma_umma_warp_specialized
)
```

See DOCS/testing\_and\_switching.md for full documentation.

## Modules

* [​`matmul`](./matmul/): SM100 Matmul CPU entry points - TMA setup and kernel launch wrappers.
* [​`matmul_kernels`](./matmul_kernels/): SM100 Matmul Kernel Structs - GPU kernel entry points and helpers.
* [​`matmul_output`](./matmul_output/): SM100 Matmul Output Pipeline - TMEM → SMEM → GMEM epilogue.
* [​`pipeline`](./pipeline/):
* [​`ring_buffer`](./ring_buffer/): Ring buffer for SM100 producer-consumer synchronization.
* [​`tile_loader`](./tile_loader/): TileLoader for SM100 matrix multiplication.
* [​`tile_scheduler`](./tile_scheduler/):
* [​`tile_scheduler_splitk`](./tile_scheduler_splitk/):
* [​`tile_writer`](./tile_writer/): TileWriter components for SM100 matrix multiplication epilogue.
--- ## blackwell_matmul_tma_umma_warp_specialized (Matmul)
`blackwell_matmul_tma_umma_warp_specialized[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b], elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: OptionalReg[UInt32] = None](c_device: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a_device: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_device: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## matmul (4)
SM100 Matmul CPU entry points - TMA setup and kernel launch wrappers. This module contains the CPU-side code for SM100 matrix multiplication: * TMA descriptor creation * Kernel instantiation and launch via ctx.enqueue\_function\_checked All GPU code (kernel structs, runtime functions) is in matmul\_kernels.mojo. ## Functions * [​`blackwell_matmul_tma_umma_warp_specialized`](./blackwell_matmul_tma_umma_warp_specialized): * [​`matmul_sm100_fallback`](./matmul_sm100_fallback):
--- ## matmul_sm100_fallback (Matmul)
`matmul_sm100_fallback[a_layout: Layout, b_layout: Layout, c_layout: Layout, c_type: DType, a_type: DType, b_type: DType, *, transpose_b: Bool, umma_shape: IndexList[3], block_tile_shape: IndexList[3], a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, a_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, b_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)`
--- ## B200MatmulSmem (Matmul_kernels)
`struct B200MatmulSmem[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, *, config: MatmulConfig[a_type, b_type, c_type, transpose_b]]` Shared memory layout for B200 SM100 matrix multiplication kernel. This struct manages the shared memory allocation for: * Input tiles (A and B matrices) with multi-stage pipelining * Output tile (C matrix) for accumulation * Synchronization barriers for producer-consumer coordination * CLC (Cluster Launch Control) barriers and response storage * TMEM (Tensor Memory) address and deallocation barrier The memory is organized to support asynchronous TMA loads and efficient bank-conflict-free access patterns for tensor core operations. Type aliases are provided for tile types (ATile, BTile, CTile) to enable cleaner function signatures without verbose LayoutTensor declarations. ## Fields * ​a\_smem (`InlineArray[B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].AType, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].a_smem_size]`): * ​b\_smem (`InlineArray[B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BType, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].b_smem_size]`): * ​c\_smem (`InlineArray[B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].CType, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].c_smem_size]`): * ​tma\_mma\_mbars (`InlineArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_group_pipeline_stages * 2)]`): * ​accum\_mbars (`InlineArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_accum_pipeline_stages * 2)]`): * ​clc\_mbars\_full (`InlineArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages]`): * ​clc\_mbars\_empty (`InlineArray[SharedMemBarrier, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages]`): * 
​clc\_throttle\_mbars (`InlineArray[SharedMemBarrier, (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages * 2)]`): * ​clc\_response (`InlineArray[UInt128, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_clc_pipeline_stages]`): * ​tmem\_dealloc\_mbar (`InlineArray[SharedMemBarrier, 1]`): * ​tmem\_addr (`InlineArray[UInt32, 1]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BM, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, config.a_swizzle]()` ### `a_smem_size` `comptime a_smem_size = ((B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BM * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK) * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages)` ### `ATile` `comptime ATile = LayoutTensor[a_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=NVIDIASharedMemoryBasePtr.alignment]` ### `ATileArray` `comptime ATileArray = SMemTileArrayType[a_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].a_smem_layout, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages, 128]` ### `AType` `comptime AType = Scalar[a_type]` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BN, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, 
config=config].BN, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK, config.b_swizzle]()` ### `b_smem_size` `comptime b_smem_size = ((B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BN * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].BK) * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages)` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTile` `comptime BTile = LayoutTensor[b_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=NVIDIASharedMemoryBasePtr.alignment]` ### `BTileArray` `comptime BTileArray = SMemTileArrayType[b_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].b_smem_layout, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages, 128]` ### `BType` `comptime BType = Scalar[b_type]` ### `c_smem_layout` `comptime c_smem_layout = Layout.row_major(B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputM, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputN)` ### `c_smem_size` `comptime c_smem_size = ((B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputM * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].OutputN) * B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_output_stages)` ### `CTile` `comptime CTile = LayoutTensor[c_type, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=NVIDIASharedMemoryBasePtr.alignment]` ### `CTileArray` `comptime CTileArray = SMemTileArrayType[c_type, B200MatmulSmem[a_type, b_type, c_type, 
transpose_b, config=config].c_smem_layout, B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_output_stages, 128]` ### `CType` `comptime CType = Scalar[c_type]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = Int(config)` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = Int(config)` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config].num_pipeline_stages // Int(config))` ### `num_output_stages` `comptime num_output_stages = Int(config)` ### `num_pipeline_stages` `comptime num_pipeline_stages = Int(config)` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `SMM` `comptime SMM = SharedMemoryManager[NVIDIASharedMemoryBasePtr]` ## Methods ### `ab_pipeline_size` `static ab_pipeline_size() -> Int` Calculate the total size of A+B tiles for all pipeline stages. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `c_output_size` `static c_output_size() -> Int` Calculate the size of C tiles for all output stages. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `total_tile_size` `static total_tile_size() -> Int` Calculate the total tile storage size (A+B+C). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
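A schematic sketch of how the static size helpers might feed a dynamic shared-memory request; `Config` stands in for a concrete `MatmulConfig` instantiation and the dtypes are placeholder choices:

```mojo
# Illustrative only: query tile storage for a hypothetical config.
comptime Smem = B200MatmulSmem[
    DType.bfloat16, DType.bfloat16, DType.bfloat16, True, config=Config
]
comptime ab_storage = Smem.ab_pipeline_size()  # A+B across all stages
comptime c_storage = Smem.c_output_size()      # C across output stages
comptime tile_storage = Smem.total_tile_size() # A+B+C combined
# Barriers, CLC responses, and TMEM bookkeeping add a fixed overhead
# on top of the tile storage when sizing the launch.
```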
--- ## BlackwellMatmulSM100FallbackKernel
`struct BlackwellMatmulSM100FallbackKernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, block_tile_shape: IndexList[3], mma_shape: IndexList[3], transpose_b: Bool = True, cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1, 1, 1), a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, num_threads: UInt = 128, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None]` Simple fallback matmul kernel for SM100 (B200). This kernel is used when the warp-specialized kernel is not applicable, such as for small problem sizes or unsupported configurations. Unlike the main BlackwellMatmulSM100Kernel, this uses: * Single warp approach (no warp specialization) * Basic barrier synchronization (no CLC scheduling) * Direct LayoutTensor output (no TMA for C) * Simpler pipeline with single buffer ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_size` `comptime a_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout.size()` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, 
b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, a_swizzle]()` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `ATile` `comptime ATile = LayoutTensor[a_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` ### `b_size` `comptime b_size = BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout.size()` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK, b_swizzle]()` ### `BK` `comptime BK = 
block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `BTile` `comptime BTile = LayoutTensor[b_type, BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` ### `c_frag_size` `comptime c_frag_size = ((BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M * BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N) // Int(num_threads))` ### `max_tmem_cols` `comptime max_tmem_cols = 512` ### `MMA_K` `comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)` ### `num_k_mmas` `comptime num_k_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BK // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = 
(BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BM // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_M)` ### `num_n_mmas` `comptime num_n_mmas = (BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].BN // BlackwellMatmulSM100FallbackKernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, block_tile_shape, mma_shape, transpose_b, cluster_shape, a_swizzle, b_swizzle, num_threads, elementwise_lambda_fn].MMA_N)` ## Methods ### `validate_constraints` `static validate_constraints()` Validate compile-time constraints for this kernel configuration. ### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], num_iters: UInt)` Run the fallback matmul kernel. **Args:** * ​a\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix A. * ​b\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix B. * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor C (LayoutTensor, not TMA). * ​num\_iters ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Number of K-dimension iterations.
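A schematic device-side sketch of driving the fallback kernel; every parameter and argument name stands for a concrete value prepared by the host wrapper (see `matmul_sm100_fallback`), so this is not runnable as written:

```mojo
# Hypothetical instantiation: parameter values are chosen by the host.
comptime Kernel = BlackwellMatmulSM100FallbackKernel[
    a_type, b_type, c_type,
    a_layout, b_layout, c_layout,
    a_desc_layout, b_desc_layout,
    block_tile_shape, mma_shape,
]
Kernel.validate_constraints()  # compile-time checks on the tile config
Kernel.run(a_tma_op, b_tma_op, c, num_iters)
```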
--- ## BlackwellMatmulSM100Kernel
`struct BlackwellMatmulSM100Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], cluster_shape: StaticTuple[Int32, 3] = StaticTuple[Int32, 3](1), elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, pdl_level: PDLLevel = PDLLevel(), max_profiled_tiles_per_SM: UInt32 = 0]` Blackwell SM100 GEMM kernel with warp specialization. This struct unifies all parameters and derived types for the SM100 matmul kernel, providing: * Compile-time parameter validation * Centralized derived type computation * Factory methods for kernel components * Multiple kernel entry points (standard, split-k) The SM100 kernel uses: * Tensor Memory (TMEM) for MMA accumulators * Cluster Launch Control (CLC) for dynamic tile scheduling * Warp specialization: Scheduler, TMA Load, MMA, Epilogue warps * Software pipelining for overlapping compute and memory operations ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, 
max_profiled_tiles_per_SM].BK, config.a_swizzle]()` ### `accum_pipeline_consumer_arv_count` `comptime accum_pipeline_consumer_arv_count = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)` ### `accum_pipeline_producer_arv_count` `comptime accum_pipeline_producer_arv_count = 1` ### `accum_type` `comptime accum_type = MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]() if transpose_b else tile_layout_mn_major[b_type, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, 
register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, config.b_swizzle]()` ### `BK` `comptime BK = config.block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = config.block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = config.block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `clc_consumer_arv_count` `comptime clc_consumer_arv_count = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_SIZE * ((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)))` ### `clc_producer_arv_count` `comptime clc_producer_arv_count = 1` ### `clc_throttle_consumer_arv_count` `comptime clc_throttle_consumer_arv_count = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, 
b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS` ### `clc_throttle_producer_arv_count` `comptime clc_throttle_producer_arv_count = BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS` ### `CLUSTER_M` `comptime CLUSTER_M = Int.__init__[Int](config.cluster_shape.__getitem__[3, DType.int64, Int](0))` ### `CLUSTER_N` `comptime CLUSTER_N = Int.__init__[Int](config.cluster_shape.__getitem__[3, DType.int64, Int](1))` ### `CLUSTER_SIZE` `comptime CLUSTER_SIZE = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M * BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N)` ### `Context` `comptime Context = KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, 
BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N]` ### `cta_group` `comptime cta_group = config.cta_group` ### `EPILOGUE_THREADS` `comptime EPILOGUE_THREADS = (4 * WARP_SIZE)` ### `EpilogueConf` `comptime EpilogueConf = EpilogueConfig[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.c_smem_layout.shape[1].value(), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, False]` ### `MMA_K` `comptime MMA_K = config.mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = config.mma_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` 
`comptime MMA_N = config.mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MMA_THREADS` `comptime MMA_THREADS = WARP_SIZE` ### `MmaOp` `comptime MmaOp = MmaOpSM100_SS[c_type, a_type, b_type, config.block_tile_shape, config.mma_shape, accum_type=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].accum_type, cta_group=BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, cluster_shape=config.cluster_shape, a_swizzle=config.a_swizzle, b_swizzle=config.b_swizzle, transpose_b=transpose_b]` ### `num_accum_pipeline_stages` `comptime num_accum_pipeline_stages = Int(config)` ### `num_clc_pipeline_stages` `comptime num_clc_pipeline_stages = Int(config)` ### `num_group_pipeline_stages` `comptime num_group_pipeline_stages = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_pipeline_stages // Int(config))` ### `num_k_mmas` `comptime num_k_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, 
max_profiled_tiles_per_SM].MMA_K)` ### `num_m_mmas` `comptime num_m_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_M // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group))` ### `num_n_mmas` `comptime num_n_mmas = (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN // (BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N // BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group))` ### `num_output_stages` `comptime num_output_stages = Int(config)` ### `num_output_warps` `comptime num_output_warps = 4` ### `num_pipeline_stages` `comptime num_pipeline_stages = Int(config)` ### `NUM_THREADS` `comptime NUM_THREADS = (((BlackwellMatmulSM100Kernel[a_type, b_type, c_type, 
a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SCHEDULER_THREADS + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].TMA_LOAD_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_THREADS) + BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].EPILOGUE_THREADS)` ### `NUM_TMEM_COLS` `comptime NUM_TMEM_COLS = 512` ### `OutputM` `comptime OutputM = config.output_tile_shape.__getitem__[2, DType.int64, Int](0)` ### `OutputN` `comptime OutputN = config.output_tile_shape.__getitem__[2, DType.int64, Int](1)` ### `OutputRB` `comptime OutputRB = OutputRingBuffer[Int(config), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].stage_stride_cols, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group]` ### `RingBuffer` `comptime RingBuffer = RingBuffer[a_type, b_type, 
BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.a_smem_layout, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.b_smem_layout, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages, Int(config)]` ### `Scheduler` `comptime Scheduler = TileScheduler[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, Index[dtype=DType.uint32](config.cluster_shape.__getitem__[3, DType.int64, Int](0), config.cluster_shape.__getitem__[3, DType.int64, Int](1), config.cluster_shape.__getitem__[3, DType.int64, Int](2)), config.raster_order, config.block_swizzle_size]` ### `SCHEDULER_THREADS` `comptime SCHEDULER_THREADS = WARP_SIZE` ### `SmemType` `comptime SmemType = B200MatmulSmem[a_type, b_type, c_type, transpose_b, config=config]` ### `stage_stride_cols` `comptime stage_stride_cols = (512 // 
BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_accum_pipeline_stages)` ### `TileLoaderTMA` `comptime TileLoaderTMA = TileLoaderTMA[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.a_smem_layout, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.b_smem_layout, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BM, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BN, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].BK, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].MMA_N, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, 
a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, Int(config), BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].SmemType.num_group_pipeline_stages]` ### `TMA_LOAD_THREADS` `comptime TMA_LOAD_THREADS = WARP_SIZE` ## Methods ### `validate_constraints` `static validate_constraints()` Validate parameter constraints at compile time. ### `init_barriers` `static init_barriers(ctx: KernelContext[BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].num_clc_pipeline_stages, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].cta_group, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_M, BlackwellMatmulSM100Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, a_desc_layout, b_desc_layout, c_desc_layout, 
transpose_b, config, cluster_shape, elementwise_compute_lambda_fn, register_based_epilogue, pdl_level, max_profiled_tiles_per_SM].CLUSTER_N], a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], tma_mma_mbars_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], accum_mbars_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_throttle_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_full_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], clc_empty_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tmem_dealloc_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED])` Initialize barriers and prefetch TMA descriptors. Called by elect\_one\_warp && elect\_one\_thread. ### `mma` `static mma(tmem_addr: UInt32, tiles: ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)` Execute MMA operations for one pipeline stage. This is the core MMA function designed to be called within a consumer tiles context: ``` with consumer.get_tiles() as tiles: Self.mma(tmem_addr, tiles, mma_op, ...) ``` **Args:** * ​tmem\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Tensor memory address for accumulators. * ​tiles ([`ConsumerTiles`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/ring_buffer/ConsumerTiles)): ConsumerTiles context with stage, mbar, and tile arrays. 
* ​mma\_op ([`MmaOpSM100_SS`](/mojo/kernels/linalg/arch/sm100/mma/MmaOpSM100_SS)): The MMA operation instance. * ​elect\_one\_warp ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether this warp should execute. * ​iter\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): K iteration index. * ​k\_start ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Starting K iteration (for init\_c determination). ### `run` `static run(a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])` Main kernel entry point for SM100 matrix multiplication. ### `run_splitk` `static run_splitk[reduction_layout: Layout](a_tma_op: TMATensorTile[a_type, a_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], reduction_tensor: LayoutTensor[MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type, reduction_layout, MutAnyOrigin], lock_ptr: LegacyUnsafePointer[UInt8], cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], workspace: Span[UInt64, MutAnyOrigin])` Split-K kernel entry point for better parallelism on small problems. Split-K divides the K dimension across multiple CTAs, with each CTA computing a partial result that is then reduced. **Args:** * ​a\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix A. * ​b\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix B. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix C. * ​reduction\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Workspace for partial results from each split. 
* ​lock\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Synchronization locks for reduction coordination. * ​cluster\_dim ([`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)): Cluster dimensions. * ​mnk ([`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)): Problem dimensions (M, N, K). * ​workspace ([`Span`](/mojo/stdlib/memory/span/Span)): Workspace buffer for profiling/scheduling.
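The split-K scheme described above can be modeled in plain Python: each "CTA" computes a partial product over its slice of the K dimension, and a reduction step sums the partials. This is a conceptual sketch only, not the GPU implementation (which uses TMA, TMEM accumulators, and lock-guarded reduction); all names here are illustrative.

```python
def matmul(A, B):
    """Reference matmul for small row-major Python lists."""
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

def split_k_matmul(A, B, num_splits):
    """Split K across `num_splits` workers, then reduce the partials."""
    M, K, N = len(A), len(B), len(B[0])
    chunk = (K + num_splits - 1) // num_splits
    partials = []
    for s in range(num_splits):           # each split maps to one CTA
        k0, k1 = s * chunk, min((s + 1) * chunk, K)
        if k0 >= K:                       # K not divisible: trailing empty split
            continue
        A_s = [row[k0:k1] for row in A]   # A's columns for this K-slice
        B_s = B[k0:k1]                    # B's rows for this K-slice
        partials.append(matmul(A_s, B_s))
    # Reduction: elementwise sum of the partial results.
    return [[sum(p[i][j] for p in partials) for j in range(N)]
            for i in range(M)]
```

Because addition is the only cross-split operation, the reduction can proceed as soon as any two partials are ready, which is what the lock-coordinated workspace enables on the GPU.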
--- ## KernelContext
`struct KernelContext[num_clc_pipeline_stages: Int, cta_group: Int, CLUSTER_M: Int, CLUSTER_N: Int]` Shared kernel state: election vars, CTA coords, multicast masks, pipeline states. ## Fields * ​elect\_one\_warp (`Bool`): * ​elect\_one\_thread (`Bool`): * ​elect\_one\_cta (`Bool`): * ​is\_first\_cta\_in\_cluster (`Bool`): * ​warp\_id (`UInt32`): * ​rank\_m (`UInt`): * ​rank\_n (`UInt`): * ​peer\_cta\_coord (`Tuple[UInt, UInt, UInt]`): * ​a\_multicast\_mask (`UInt16`): * ​b\_multicast\_mask (`UInt16`): * ​mma\_complete\_mask (`Int`): * ​ptr\_tmem\_addr (`LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self, ptr_tmem_addr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED])` Initialize context from TMEM pointer; computes all derived state.
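A hypothetical sketch of the coordinate fields: a CTA's linear rank within its cluster decomposes into `(rank_m, rank_n)` against the `CLUSTER_M` x `CLUSTER_N` grid. The decomposition order below is an assumption for illustration; the real kernel derives these values from hardware cluster intrinsics in `__init__`.

```python
def cluster_coords(cluster_rank, CLUSTER_M, CLUSTER_N):
    """Decompose a linear cluster rank into (rank_m, rank_n).

    Assumes M-major ordering, which is illustrative, not confirmed.
    """
    assert 0 <= cluster_rank < CLUSTER_M * CLUSTER_N
    rank_m = cluster_rank % CLUSTER_M    # position along the M axis
    rank_n = cluster_rank // CLUSTER_M   # position along the N axis
    return rank_m, rank_n
```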
--- ## WarpRole (Matmul_kernels)
`@register_passable(trivial)` `struct WarpRole` Warp role identifiers for SM100 warp-specialized kernel. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Epilogue` `comptime Epilogue = WarpRole(3)` ### `MainLoad` `comptime MainLoad = WarpRole(5)` ### `Mma` `comptime Mma = WarpRole(6)` ### `Scheduler` `comptime Scheduler = WarpRole(4)` ## Methods ### `__eq__` `__eq__(self, other: UInt) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ge__` `__ge__(self, other: UInt) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_main_load` `static is_main_load() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_mma` `static is_mma() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_epilogue` `static is_epilogue() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_scheduler` `static is_scheduler() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
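The role constants above (`Epilogue = 3`, `Scheduler = 4`, `MainLoad = 5`, `Mma = 6`) together with the `__ge__(self, other: UInt)` overload suggest that roles are assigned by comparing a warp's index against these values: four epilogue warps (0..3) followed by one warp each for scheduling, loading, and MMA. The sketch below is an inference from those constants and the thread counts elsewhere in this module, not a transcription of the kernel source.

```python
# Role boundaries inferred from the WarpRole comptime members.
EPILOGUE, SCHEDULER, MAIN_LOAD, MMA = 3, 4, 5, 6

def role_of(warp_id):
    """Classify a warp by its index (hypothetical mapping)."""
    if warp_id <= EPILOGUE:      # warps 0..3: TMEM -> GMEM output
        return "Epilogue"
    if warp_id == SCHEDULER:     # warp 4: CLC tile scheduling
        return "Scheduler"
    if warp_id == MAIN_LOAD:     # warp 5: TMA async loads
        return "MainLoad"
    if warp_id == MMA:           # warp 6: tensor-core MMA
        return "Mma"
    raise ValueError("warp_id out of range for this kernel")
```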
--- ## consumer_main_loop (Matmul_kernels)
`consumer_main_loop[accum_type: DType, c_type: DType, a_type: DType, b_type: DType, a_smem_layout: Layout, b_smem_layout: Layout, a_swizzle: TensorMapSwizzle, b_swizzle: TensorMapSwizzle, transpose_b: Bool, pipeline_stages: Int, /, *, block_tile_shape: IndexList[3], mma_shape: IndexList[3], cta_group: Int = 1, cluster_shape: IndexList[3] = Index(1, 1, 1), k_group_size: UInt = 1](tmem_addr: UInt32, a_smem_iter: LayoutTensorIter[a_type, a_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], load_mma_pipeline: ProducerConsumerPipeline[pipeline_stages], mma_op: MmaOpSM100_SS[c_type, a_type, b_type, block_tile_shape, mma_shape, accum_type=accum_type, cta_group=cta_group, cluster_shape=cluster_shape, a_swizzle=a_swizzle, b_swizzle=b_swizzle, transpose_b=transpose_b], elect_one_warp: Bool, iter_idx: UInt32, k_start: UInt32)` Consume tiles from shared memory and execute MMA operations. This is the public API for external callers using SMemTileIter.
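The number of MMA instructions the consumer issues per block tile follows directly from the comptime formulas on `BlackwellMatmulSM100Kernel` (`num_m_mmas`, `num_n_mmas`, `num_k_mmas`): the block tile `(BM, BN, BK)` is covered by MMA-shaped sub-tiles, with `MMA_M` and `MMA_N` shared across a CTA pair when `cta_group == 2`. The shapes in the test are illustrative values, not kernel defaults.

```python
def mma_counts(block_tile_shape, mma_shape, cta_group=1):
    """MMA sub-tile counts per block tile, per the comptime formulas."""
    BM, BN, BK = block_tile_shape
    MMA_M, MMA_N, MMA_K = mma_shape
    num_m_mmas = BM // (MMA_M // cta_group)  # M covered per CTA-group share
    num_n_mmas = BN // (MMA_N // cta_group)  # N covered per CTA-group share
    num_k_mmas = BK // MMA_K                 # K iterations per pipeline stage
    return num_m_mmas, num_n_mmas, num_k_mmas
```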
--- ## f32_frag_to_smem (Matmul_kernels)
`f32_frag_to_smem[swizzle_mode: TensorMapSwizzle, stageN: UInt](vec: SIMD[dtype, size], dst: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Store a float32 fragment to shared memory; this is the float32 path that `stsm_helper` delegates to. 
--- ## matmul_kernels
SM100 Matmul Kernel Structs - GPU kernel entry points and helpers. This module contains the GPU kernel structs for SM100 matmul: * WarpRole: Warp specialization roles (MMA, Load, Scheduler, Epilogue) * KernelContext: Common kernel state (election vars, CTA coords, masks) * B200MatmulSmem: Shared memory layout for the kernel * BlackwellMatmulSM100Kernel: Main kernel struct with run() and run\_splitk() * BlackwellMatmulSM100FallbackKernel: Simple fallback kernel * consumer\_main\_loop: MMA consumer loop (for external callers) Output pipeline functions (copy\_accum\_to\_gmem, multi\_stage\_store\_C) are in matmul\_output.mojo. The kernel implements a warp-specialized architecture: * Scheduler warp: CLC-based tile scheduling * TMA Load warp: Async memory transfers * MMA warp: Tensor core operations with TMEM accumulators * Epilogue warps: Output from TMEM to GMEM (see matmul\_output.mojo) ## `comptime` values ### `RLayout32Bits` `comptime RLayout32Bits[layout: Layout] = RuntimeLayout[layout, element_type=DType.uint32, linear_idx_type=DType.uint32]` #### Parameters * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Structs * [​`B200MatmulSmem`](./B200MatmulSmem): Shared memory layout for B200 SM100 matrix multiplication kernel. * [​`BlackwellMatmulSM100FallbackKernel`](./BlackwellMatmulSM100FallbackKernel): Simple fallback matmul kernel for SM100 (B200). * [​`BlackwellMatmulSM100Kernel`](./BlackwellMatmulSM100Kernel): Blackwell SM100 GEMM kernel with warp specialization. * [​`KernelContext`](./KernelContext): Shared kernel state: election vars, CTA coords, multicast masks, pipeline states. * [​`WarpRole`](./WarpRole): Warp role identifiers for SM100 warp-specialized kernel. ## Functions * [​`consumer_main_loop`](./consumer_main_loop): Consume tiles from shared memory and execute MMA operations. * [​`f32_frag_to_smem`](./f32_frag_to_smem): * [​`stsm_helper`](./stsm_helper): Store a fragment to shared memory using st.matrix.
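The warp budget implied by the specialization scheme above follows from the comptime thread counts on the kernel struct (`SCHEDULER_THREADS`, `TMA_LOAD_THREADS`, and `MMA_THREADS` are one warp each; `EPILOGUE_THREADS` is four warps), assuming the NVIDIA warp size of 32:

```python
# Per-CTA thread budget for the warp-specialized SM100 kernel.
WARP_SIZE = 32
SCHEDULER_THREADS = WARP_SIZE       # 1 warp: CLC tile scheduling
TMA_LOAD_THREADS = WARP_SIZE        # 1 warp: async TMA loads
MMA_THREADS = WARP_SIZE             # 1 warp: tensor-core MMA issue
EPILOGUE_THREADS = 4 * WARP_SIZE    # 4 warps: TMEM -> GMEM output

NUM_THREADS = (SCHEDULER_THREADS + TMA_LOAD_THREADS
               + MMA_THREADS + EPILOGUE_THREADS)  # 7 warps per CTA
```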
--- ## stsm_helper (Matmul_kernels)
`stsm_helper[swizzle: Swizzle, stageN: UInt, transpose_c: Bool = False, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B](vec: SIMD[dtype, size], dst: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_offset: UInt32 = 0)` Store a fragment to shared memory using st.matrix. Delegates to store\_fragment\_to\_smem for non-float32 types, and to f32\_frag\_to\_smem for float32.
--- ## accum_arrive (Matmul_output)
`accum_arrive[cta_group: Int, num_stages: Int](stage: OutputStage[num_stages])` Signal accumulator arrival. Delegates to AccumBarrier. **Args:** * ​stage (`OutputStage`): OutputStage containing pipeline and stage index.
--- ## copy_accum_to_gmem (Matmul_output)
`copy_accum_to_gmem[c_type: DType, c_layout: Layout, c_smem_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: Int, num_output_stages: Int, /, *, repeat: Int, accum_type: DType, cta_group: Int, epilogue_dtype: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], num_output_warps: UInt, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False](c_tiles: SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], output_stage: OutputStage[num_accum_pipeline_stages], c_coord: Tuple[UInt32, UInt32], c_shape: Tuple[UInt32, UInt32])` Epilogue pipeline: TMEM → Registers → SMEM → GMEM (via TMA). **Args:** * ​c\_tiles ([`SMemTileArrayType`](/mojo/kernels/linalg/structuring/SMemTileArrayType)): Shared memory tiles for output staging. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for C matrix. * ​output\_stage (`OutputStage`): Self-contained stage with pipeline, stage index, and TMEM offset. * ​c\_coord ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): (M, N) tile coordinates. * ​c\_shape ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): (M, N) matrix dimensions.
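The pipeline's TMEM staging follows from the comptime members `NUM_TMEM_COLS = 512` and `stage_stride_cols = 512 // num_accum_pipeline_stages`: each accumulator pipeline stage owns a contiguous slice of tensor-memory columns, so the epilogue can drain one stage while the MMA warp fills another. A small sketch of that partitioning:

```python
NUM_TMEM_COLS = 512  # total tensor-memory columns per SM (from the docs above)

def tmem_stage_slices(num_accum_pipeline_stages):
    """Half-open column ranges [start, end) owned by each accumulator stage."""
    stride = NUM_TMEM_COLS // num_accum_pipeline_stages
    return [(s * stride, (s + 1) * stride)
            for s in range(num_accum_pipeline_stages)]
```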
--- ## matmul_output
SM100 Matmul Output Pipeline - TMEM → SMEM → GMEM epilogue. This module contains the output pipeline code for SM100 matmul: * copy\_accum\_to\_gmem: Core epilogue pipeline (TMEM → Registers → SMEM → GMEM) * multi\_stage\_store\_C: Output pipeline orchestration for standard matmul * multi\_stage\_store\_C\_split\_k: Output pipeline for split-K matmul The output pipeline handles: * Loading accumulated results from Tensor Memory (TMEM) * Applying optional epilogue operations (bias, activation) * Writing to shared memory via st.matrix instructions * Transferring to global memory via TMA async stores ## Functions * [​`accum_arrive`](./accum_arrive): Signal accumulator arrival. Delegates to AccumBarrier. * [​`copy_accum_to_gmem`](./copy_accum_to_gmem): Epilogue pipeline: TMEM → Registers → SMEM → GMEM (via TMA). * [​`multi_stage_store_C`](./multi_stage_store_C): Orchestrate output from TMEM to GMEM via shared memory. * [​`multi_stage_store_C_split_k`](./multi_stage_store_C_split_k): Split-K output pipeline with reduction.
--- ## multi_stage_store_C (Matmul_output)
`multi_stage_store_C[c_type: DType, c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, num_accum_pipeline_stages: Int, num_output_stages: Int, /, *, input_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], stage_stride_cols: UInt, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 1, num_output_warps: UInt = 4, max_tmem_cols: UInt = 512, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False](c_tiles: SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], stage: OutputStage[num_accum_pipeline_stages], work_tile_coord: Tuple[UInt32, UInt32], elect_one_warp: Bool, M: UInt32, N: UInt32)` Orchestrate output from TMEM to GMEM via shared memory. **Args:** * ​c\_tiles ([`SMemTileArrayType`](/mojo/kernels/linalg/structuring/SMemTileArrayType)): Shared memory tiles for output staging. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for C matrix. * ​stage (`OutputStage`): Self-contained output stage with pipeline, stage index, and TMEM offset. * ​work\_tile\_coord ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): (M, N) tile coordinates. * ​elect\_one\_warp ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether this warp is elected. * ​M ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Matrix M dimension. * ​N ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Matrix N dimension.
--- ## multi_stage_store_C_split_k (Matmul_output)
`multi_stage_store_C_split_k[c_type: DType, c_smem_layout: Layout, c_layout: Layout, c_desc_layout: Layout, reduction_layout: Layout, num_accum_pipeline_stages: Int, num_output_stages: Int, /, *, input_type: DType, accum_type: DType, block_tile_shape: IndexList[3], mma_shape: IndexList[3], stage_stride_cols: UInt, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, cta_group: Int = 1, num_output_warps: UInt = 4, max_tmem_cols: UInt = 512, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, register_based_epilogue: Bool = True, transpose_c: Bool = False](scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], reduction_tensor: LayoutTensor[accum_type, reduction_layout, MutAnyOrigin], c_tiles: SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], stage: OutputStage[num_accum_pipeline_stages], work_info: WorkInfo, elect_one_warp: Bool, M: UInt32, N: UInt32)` Split-K output pipeline with reduction. **Args:** * ​scheduler ([`TileScheduler`](/mojo/kernels/linalg/grouped_matmul_tile_scheduler/TileScheduler)): Split-K tile scheduler for reduction. * ​reduction\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor for accumulating partial results. * ​c\_tiles ([`SMemTileArrayType`](/mojo/kernels/linalg/structuring/SMemTileArrayType)): Shared memory tiles for output staging. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for C matrix. * ​stage (`OutputStage`): Self-contained output stage with pipeline, stage index, and TMEM offset. * ​work\_info ([`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo)): Current work item info. * ​elect\_one\_warp ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether this warp is elected. 
* ​M ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Matrix M dimension. * ​N ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Matrix N dimension.
--- ## ProducerConsumerPipeline (Pipeline)
`@register_passable(trivial)` `struct ProducerConsumerPipeline[num_stages: Int]` A producer-consumer pipeline using shared memory barriers to enforce synchronization (between producer and consumer warps). This struct is commonly used with warp specialization to pipeline operations between two warps/warpgroups with data dependencies. ## Parameters * ​num\_stages ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of pipeline stages. ## Fields * ​full (`MbarPtr`): * ​empty (`MbarPtr`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` Initialize the producer-consumer pipeline with default phases. **Args:** * ​ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barriers. ### `wait_producer` `wait_producer(self)` Consumer waits for producer. ### `wait_consumer` `wait_consumer(self)` Producer waits for consumer. ### `producer_mbar` `producer_mbar(self, stage: UInt32) -> MbarPtr` Get the producer barrier for a specific stage. **Args:** * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The pipeline stage. **Returns:** `MbarPtr`: The shared memory barrier that the producer signals. ### `consumer_mbar` `consumer_mbar(self, stage: UInt32) -> MbarPtr` Get the consumer barrier for a specific stage. **Args:** * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The pipeline stage. 
**Returns:** `MbarPtr`: The shared memory barrier that the consumer signals. ### `producer_stage` `producer_stage(self) -> UInt32` Get the current producer stage index. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The current stage index for the producer (0 to num\_stages-1). ### `consumer_stage` `consumer_stage(self) -> UInt32` Get the current consumer stage index. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The current stage index for the consumer (0 to num\_stages-1). ### `consumer_step` `consumer_step(mut self)` Advance the consumer to the next pipeline stage. Increments the consumer stage and wraps to 0 when reaching num\_stages, toggling the phase bit on wrap-around. The phase is only switched at the end of the pipeline because all barriers are assumed to be at the same consumer/producer phase before being checked; once a barrier is checked, execution moves on to the next barrier. ### `producer_step` `producer_step(mut self)` Advance the producer to the next pipeline stage. Increments the producer stage and wraps to 0 when reaching num\_stages, toggling the phase bit on wrap-around. ### `smem_bytes` `static smem_bytes() -> UInt32` Calculate the shared memory bytes required for pipeline barriers. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The total number of bytes needed for all pipeline barriers (2 \* num\_stages barriers). ### `init_mbars` `init_mbars(self, producer_arrive_count: Int32, consumer_arrive_count: Int32)` Initialize the smem barriers for the producer and consumer. This function must be called by a single thread, and must be called before the pipeline object is used. **Args:** * ​producer\_arrive\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The number of threads that will arrive at the barrier marking data as produced. * ​consumer\_arrive\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The number of threads that will arrive at the barrier marking data as consumed. 
### `producer_signal_and_step` `producer_signal_and_step(mut self)` Wait for consumer, signal production complete, and advance stage. Combines: wait\_consumer() + arrive on producer\_mbar + producer\_step() Used for CLC throttling in the Load warp. ### `consumer_signal_and_step` `consumer_signal_and_step(mut self)` Wait for producer, signal consumption complete, and advance stage. Combines: wait\_producer() + arrive on consumer\_mbar + consumer\_step() Used for CLC throttling in the Scheduler warp.
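The stepping and signaling methods above compose into a two-warp handshake. A minimal sketch (hypothetical: `do_produce`/`do_consume` stand in for real kernel work, `num_iterations` is a placeholder, and the exact arrival call on `SharedMemBarrier` is assumed):

```
# Producer warp: wait for a free stage, fill it, signal, advance.
for _ in range(num_iterations):
    pipeline.wait_consumer()                  # consumer has freed this stage
    do_produce(pipeline.producer_stage())     # user code fills the stage
    _ = pipeline.producer_mbar(pipeline.producer_stage())[].arrive()
    pipeline.producer_step()                  # wraps and toggles phase at end

# Consumer warp: wait for data, use it, signal, advance.
for _ in range(num_iterations):
    pipeline.wait_producer()                  # producer has filled this stage
    do_consume(pipeline.consumer_stage())     # user code reads the stage
    _ = pipeline.consumer_mbar(pipeline.consumer_stage())[].arrive()
    pipeline.consumer_step()
```

The fused `producer_signal_and_step`/`consumer_signal_and_step` methods below combine the wait, arrive, and step of each loop body into a single call.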
--- ## pipeline (3)
## `comptime` values ### `MbarPtr` `comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` ## Structs * [​`ProducerConsumerPipeline`](./ProducerConsumerPipeline): A producer-consumer pipeline using shared memory barriers to enforce synchronization (between producer and consumer warps).
--- ## Consumer
`@register_passable(trivial)` `struct Consumer[origin: MutOrigin, a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Consumer view with get\_tiles() API. ## Fields * ​ring\_buffer\_ptr (`Pointer[Consumer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].RingBufferType, origin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `RingBufferType` `comptime RingBufferType = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `get_tiles` `get_tiles(mut self) -> ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` Get the next slot with stage, barrier, and tile arrays. Synchronization is handled internally - waits for tiles to be ready. **Returns:** [`ConsumerTiles`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/ring_buffer/ConsumerTiles)
--- ## ConsumerTiles
`@register_passable(trivial)` `struct ConsumerTiles[origin: MutOrigin, a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Context manager for consumer access with stage, barrier, and tile arrays. Provides everything needed to process k\_group tiles in a single context: \- stage: Current pipeline stage index \- mbar: Barrier for synchronization (tiles ready when context entered) \- a\_tiles: Full A tile array (index with stage \* k\_group\_size + j) \- b\_tiles: Full B tile array (index with stage \* k\_group\_size + j) ## Fields * ​ring\_buffer\_ptr (`Pointer[ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].RingBufferType, origin]`): * ​stage (`UInt32`): * ​mbar (`MbarPtr`): * ​a\_tiles (`ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray`): * ​b\_tiles (`ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATileArray` `comptime ATileArray = ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].RingBufferType.ATileArray` ### `BTileArray` `comptime BTileArray = ConsumerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, 
num_pipeline_stages, num_group_stages, k_group_size].RingBufferType.BTileArray` ### `RingBufferType` `comptime RingBufferType = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)`
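A sketch of the consumer side (hypothetical loop bound `num_k_groups`; tile indexing follows the `stage * k_group_size + j` convention documented above):

```
with ring_buffer.consumer() as consumer:
    for _ in range(num_k_groups):
        with consumer.get_tiles() as tiles:   # blocks until tiles are ready
            @parameter
            for j in range(k_group_size):
                var idx = tiles.stage * k_group_size + j
                # ... issue MMA on tiles.a_tiles[idx], tiles.b_tiles[idx] ...
        # leaving the context signals the producer that the slot is free
```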
--- ## OutputConsumerContext
`@register_passable(trivial)` `struct OutputConsumerContext[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Context manager for epilogue consumer access to OutputRingBuffer. Automatically calls acquire\_for\_epilogue on enter and release\_from\_epilogue on exit. Usage: with output\_rb.consumer() as stage: \# ... read from stage.tmem\_offset, write to GMEM ... \# release\_from\_epilogue called automatically ## Fields * ​ring\_buffer\_ptr (`Pointer[OutputConsumerContext[origin, num_stages, stage_stride_cols, cta_group].RingBufferType, origin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `RingBufferType` `comptime RingBufferType = OutputRingBuffer[num_stages, stage_stride_cols, cta_group]` ### `Stage` `comptime Stage = OutputStage[num_stages]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[OutputConsumerContext[origin, num_stages, stage_stride_cols, cta_group].RingBufferType, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> OutputConsumerContext[origin, num_stages, stage_stride_cols, cta_group].Stage` **Returns:** `OutputStage`: The acquired output stage. ### `__exit__` `__exit__(mut self)`
--- ## OutputProducerContext
`@register_passable(trivial)` `struct OutputProducerContext[origin: MutOrigin, num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Context manager for MMA producer access to OutputRingBuffer. Automatically calls acquire\_for\_mma on enter and release\_from\_mma on exit. Usage: with output\_rb.producer() as stage: \# ... MMA into stage.tmem\_offset ... \# release\_from\_mma called automatically ## Fields * ​ring\_buffer\_ptr (`Pointer[OutputProducerContext[origin, num_stages, stage_stride_cols, cta_group].RingBufferType, origin]`): * ​stage (`OutputProducerContext[origin, num_stages, stage_stride_cols, cta_group].Stage`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `RingBufferType` `comptime RingBufferType = OutputRingBuffer[num_stages, stage_stride_cols, cta_group]` ### `Stage` `comptime Stage = OutputStage[num_stages]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[OutputProducerContext[origin, num_stages, stage_stride_cols, cta_group].RingBufferType, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> OutputProducerContext[origin, num_stages, stage_stride_cols, cta_group].Stage` **Returns:** `OutputStage`: The acquired output stage. ### `__exit__` `__exit__(mut self)`
--- ## OutputRingBuffer
`@register_passable(trivial)` `struct OutputRingBuffer[num_stages: Int, stage_stride_cols: Int, cta_group: Int]` Ring buffer for MMA→Epilogue output pipeline. Manages TMEM accumulator stage synchronization between MMA warps (producer) and Epilogue warps (consumer). Unlike RingBuffer which manages SMEM tiles, this manages stage indices and computes TMEM offsets. The TMEM itself is allocated separately via tcgen05\_alloc; this struct only coordinates access to different stages within that allocation. Template Parameters: num\_stages: Number of accumulator pipeline stages. stage\_stride\_cols: TMEM column stride between stages. cta\_group: CTA group size (1 or 2) for multicast signaling. Usage: \# Initialize barriers once (elect\_one\_warp/elect\_one\_thread): OutputRingBuffer\[...].init\_barriers(storage\_ptr, prod\_cnt, cons\_cnt) ``` # Create ring buffer (each warp creates its own): var output_rb = OutputRingBuffer[...](storage_ptr, tmem_addr, mask) # MMA warp (producer): with output_rb.producer() as stage: # ... perform MMA into stage.tmem_offset ... # Epilogue warp (consumer): with output_rb.consumer() as stage: # ... read from stage.tmem_offset, write to GMEM ... 
``` ## Fields * ​pipeline (`OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Pipeline`): * ​tmem\_base\_addr (`UInt32`): * ​mma\_complete\_mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Pipeline` `comptime Pipeline = ProducerConsumerPipeline[num_stages]` ### `Stage` `comptime Stage = OutputStage[num_stages]` ## Methods ### `__init__` `__init__(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], tmem_base_addr: UInt32, mma_complete_mask: UInt16) -> Self` Initialize output ring buffer. Creates pipeline internally from storage pointer. Barriers must be initialized via init\_barriers() before first use. **Args:** * ​storage\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barrier storage. * ​tmem\_base\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base TMEM address for accumulators. * ​mma\_complete\_mask ([`UInt16`](/mojo/stdlib/builtin/simd/#uint16)): Multicast mask for 2-SM MMA completion signaling. ### `init_barriers` `static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize pipeline barriers. Called once by elect\_one thread. **Args:** * ​storage\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barrier storage. 
* ​producer\_arv\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Expected arrival count for producer barriers. * ​consumer\_arv\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Expected arrival count for consumer barriers. ### `acquire_for_mma` `acquire_for_mma(self) -> OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Stage` Acquire a stage for MMA computation. Waits for the epilogue to finish with this stage, then returns the stage info with computed TMEM offset and pipeline reference. **Returns:** `OutputStage`: The stage index, TMEM offset, and pipeline for signaling. ### `release_from_mma` `release_from_mma(mut self, stage: OutputStage[num_stages])` Signal MMA completion and advance to next stage. Signals the epilogue that accumulator data is ready, using either mma\_arrive (1-SM) or mma\_arrive\_multicast (2-SM). **Args:** * ​stage ([`OutputStage`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/ring_buffer/OutputStage)): The stage being released (from acquire\_for\_mma). ### `acquire_for_epilogue` `acquire_for_epilogue(self) -> OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Stage` Acquire a stage for epilogue processing. Waits for MMA to complete this stage, then returns the stage info. **Returns:** `OutputStage`: The stage index, TMEM offset, and pipeline for signaling. ### `release_from_epilogue` `release_from_epilogue(mut self)` Signal epilogue completion and advance to next stage. Signals MMA that this accumulator stage is free for reuse. ### `producer` `producer[origin: MutOrigin, //](ref [origin] self) -> OutputProducerContext[origin, num_stages, stage_stride_cols, cta_group]` Get a producer context for MMA warp. 
Usage: with output\_rb.producer() as stage: \# MMA into stage.tmem\_offset \# release\_from\_mma called automatically **Returns:** `OutputProducerContext` ### `consumer` `consumer[origin: MutOrigin, //](ref [origin] self) -> OutputConsumerContext[origin, num_stages, stage_stride_cols, cta_group]` Get a consumer context for epilogue warp. Usage: with output\_rb.consumer() as stage: \# Read from stage.tmem\_offset, write to GMEM \# release\_from\_epilogue called automatically **Returns:** `OutputConsumerContext` ### `get_pipeline` `get_pipeline(self) -> OutputRingBuffer[num_stages, stage_stride_cols, cta_group].Pipeline` Get the underlying pipeline for barrier initialization. Note: With OutputStage now carrying the pipeline, most code no longer needs this. It's retained for init\_barriers() which needs the raw pipeline before any OutputStage instances exist. **Returns:** `Pipeline`: The underlying `ProducerConsumerPipeline`.
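For code that cannot use the context managers, the same protocol can be written with the explicit acquire/release pairs documented above (a sketch; the MMA and epilogue bodies are placeholders):

```
# MMA warp (producer side):
var mma_stage = output_rb.acquire_for_mma()       # waits for a free TMEM stage
# ... issue MMA writing accumulators at mma_stage.tmem_offset ...
output_rb.release_from_mma(mma_stage)             # mma_arrive / mma_arrive_multicast

# Epilogue warp (consumer side):
var epi_stage = output_rb.acquire_for_epilogue()  # waits for MMA completion
# ... read accumulators at epi_stage.tmem_offset, store to GMEM ...
output_rb.release_from_epilogue()                 # stage is free for reuse
```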
--- ## OutputStage
`@register_passable(trivial)` `struct OutputStage[num_stages: Int]` Stage info for output pipeline. Contains the stage index, computed TMEM offset, and a copy of the pipeline. This makes the stage self-contained, eliminating the need to pass the pipeline separately to functions like multi\_stage\_store\_C. Template Parameters: num\_stages: Number of pipeline stages (must match the OutputRingBuffer). ## Fields * ​stage (`UInt32`): * ​tmem\_offset (`UInt32`): * ​pipeline (`OutputStage[num_stages].Pipeline`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Pipeline` `comptime Pipeline = ProducerConsumerPipeline[num_stages]` ## Methods ### `__init__` `__init__(stage: UInt32, tmem_offset: UInt32, pipeline: ProducerConsumerPipeline[num_stages]) -> Self`
--- ## Producer
`@register_passable(trivial)` `struct Producer[origin: MutOrigin, a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Producer view with get\_tiles() API. ## Fields * ​ring\_buffer\_ptr (`Pointer[Producer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].RingBufferType, origin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `RingBufferType` `comptime RingBufferType = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)` ### `drain` `drain(mut self)` Drain all pending pipeline stages. Prevents the CTA from exiting while a peer CTA is still working on MMA. Waits for each consumer slot and releases it, cycling through all num\_group\_stages stages. ### `get_tiles` `get_tiles(mut self) -> ProducerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` Get the next available slot with stage, barrier, and tile arrays. Synchronization is handled internally - waits for slot availability. **Returns:** `ProducerTiles`
--- ## ProducerTiles
`@register_passable(trivial)` `struct ProducerTiles[origin: MutOrigin, a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Context manager for producer access with stage, barrier, and tile arrays. Provides everything needed to load k\_group tiles in a single context: \- stage: Current pipeline stage index \- barrier: Barrier for synchronization (call expect\_bytes once for all k\_group) \- a\_tiles: Full A tile array (index with stage \* k\_group\_size + j) \- b\_tiles: Full B tile array (index with stage \* k\_group\_size + j) ## Fields * ​ring\_buffer\_ptr (`Pointer[ProducerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].RingBufferType, origin]`): * ​stage (`UInt32`): * ​barrier (`MbarPtr`): * ​a\_tiles (`ProducerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray`): * ​b\_tiles (`ProducerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATileArray` `comptime ATileArray = ProducerTiles[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].RingBufferType.ATileArray` ### `BTileArray` `comptime BTileArray = ProducerTiles[origin, a_type, b_type, a_tile_layout, 
b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].RingBufferType.BTileArray` ### `RingBufferType` `comptime RingBufferType = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` ## Methods ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)`
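The producer side mirrors the consumer usage: a single `expect_bytes` on the stage barrier covers all `k_group_size` tiles. A sketch (hypothetical loop bound `num_k_groups`; `bytes_per_stage` and the TMA copies are placeholders):

```
with ring_buffer.producer() as producer:
    for _ in range(num_k_groups):
        with producer.get_tiles() as tiles:   # blocks until a slot is free
            # One expect_bytes call covers every tile in this k_group.
            tiles.barrier[].expect_bytes(bytes_per_stage)
            @parameter
            for j in range(k_group_size):
                var idx = tiles.stage * k_group_size + j
                # ... TMA-load into tiles.a_tiles[idx] and tiles.b_tiles[idx] ...
        # leaving the context marks the stage as produced
    producer.drain()  # keep the CTA alive until peer CTAs finish MMA
```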
--- ## RingBuffer (Ring_buffer)
`@register_passable(trivial)` `struct RingBuffer[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_group_stages: Int, k_group_size: Int]` Ring buffer with tile storage for SM100 producer-consumer sync. This is the SM90-style API where tiles are stored in the ring buffer and returned directly from get\_tiles(). Template Parameters: a\_type: Data type for A matrix tiles. b\_type: Data type for B matrix tiles. a\_tile\_layout: Memory layout for A tiles. b\_tile\_layout: Memory layout for B tiles. num\_pipeline\_stages: Total number of tile stages. num\_group\_stages: Number of synchronization stages. k\_group\_size: Number of tiles per synchronization stage. ## Fields * ​pipeline (`RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].Pipeline`): * ​a\_tiles (`RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray`): * ​b\_tiles (`RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray.Tile` ### `ATileArray` `comptime ATileArray = SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128]` ### `BTile` `comptime BTile = 
RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray.Tile` ### `BTileArray` `comptime BTileArray = SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]` ### `Pipeline` `comptime Pipeline = ProducerConsumerPipeline[num_group_stages]` ## Methods ### `__init__` `__init__(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, 128], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, 128]) -> Self` Initialize ring buffer from storage pointer. Creates pipeline internally from storage pointer. Barriers must be initialized via init\_barriers() before first use. **Args:** * ​storage\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barrier storage. * ​a\_tiles ([`SMemTileArrayType`](/mojo/kernels/linalg/structuring/SMemTileArrayType)): A matrix tile array in shared memory. * ​b\_tiles ([`SMemTileArrayType`](/mojo/kernels/linalg/structuring/SMemTileArrayType)): B matrix tile array in shared memory. ### `init_barriers` `static init_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize pipeline barriers. Called once by elect\_one thread. **Args:** * ​storage\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barrier storage. * ​producer\_arv\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Expected arrival count for producer barriers. * ​consumer\_arv\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Expected arrival count for consumer barriers. 
### `producer` `producer[origin: MutOrigin](ref [origin] self) -> Producer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` Get producer view with get\_tiles() API. **Returns:** `Producer` ### `consumer` `consumer[origin: MutOrigin](ref [origin] self) -> Consumer[origin, a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size]` Get consumer view with get\_tiles() API. **Returns:** `Consumer` ### `get_producer_tiles` `get_producer_tiles(mut self) -> Tuple[UInt32, MbarPtr, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray]` Wait for slot and return stage, barrier, and tile arrays. Synchronization is handled internally - waits for consumer to release slot. Use stage to index: tiles.a\_tiles\[stage \* k\_group\_size + j] **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (stage, barrier, a\_tiles, b\_tiles). ### `enqueue_tile` `enqueue_tile(mut self)` Signal producer finished loading and advance stage. ### `get_tile` `get_tile[tile_idx_in_group: Int](self, stage: UInt32) -> Tuple[RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATile, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTile]` Get tiles at specific index within the current k\_group. 
**Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `get_consumer_tiles` `get_consumer_tiles(mut self) -> Tuple[UInt32, MbarPtr, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].ATileArray, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_group_stages, k_group_size].BTileArray]` Wait for slot and return stage, barrier, and tile arrays. Synchronization is handled internally - waits for producer to fill slot. Use stage to index: tiles.a\_tiles\[stage \* k\_group\_size + j] **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (stage, mbar, a\_tiles, b\_tiles). ### `release_slot` `release_slot(mut self)` Signal consumer finished and advance stage. ### `producer_stage` `producer_stage(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `consumer_stage` `consumer_stage(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `producer_mbar` `producer_mbar(self, stage: UInt32) -> MbarPtr` **Returns:** `MbarPtr` ### `consumer_mbar` `consumer_mbar(self, stage: UInt32) -> MbarPtr` **Returns:** `MbarPtr`
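The raw slot methods above compose into an explicit producer/consumer loop without the context managers. A pseudocode-level sketch of device code (the loop bound `num_group_iters` and the `load_a_tile`/`load_b_tile`/`mma` helpers are illustrative assumptions, not part of this API):

```
# Producer warp: wait for a free slot, fill all tiles in the k-group, publish.
for _ in range(num_group_iters):
    var stage, mbar, a_tiles, b_tiles = ring_buffer.get_producer_tiles()

    @parameter
    for j in range(k_group_size):
        # Index within the current k-group, as documented above.
        load_a_tile(a_tiles[Int(stage) * k_group_size + j])
        load_b_tile(b_tiles[Int(stage) * k_group_size + j])
    ring_buffer.enqueue_tile()  # signal consumers, advance producer stage

# Consumer warp: wait for a filled slot, consume it, release it.
for _ in range(num_group_iters):
    var stage, mbar, a_tiles, b_tiles = ring_buffer.get_consumer_tiles()

    @parameter
    for j in range(k_group_size):
        var a, b = ring_buffer.get_tile[j](stage)
        mma(a, b)
    ring_buffer.release_slot()  # signal producer, advance consumer stage
```

Note that `get_tile` takes its in-group index as a compile-time parameter, which is why the inner loop is written as a `@parameter for` rather than a runtime loop.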
--- ## ring_buffer
Ring buffer for SM100 producer-consumer synchronization. Provides the SM90-style get\_tiles() API for TMA-MMA pipeline synchronization. Usage: ``` var ring_buffer = RingBuffer[...](storage_ptr, a_tiles, b_tiles)  # barriers via init_barriers() before first use # Producer: tiles contains stage, barrier, a_tiles, b_tiles with ring_buffer.producer() as producer: with producer.get_tiles() as tiles: load_tiles(tiles) # Access tiles.a_tiles[stage * k_group + j] # Consumer: tiles contains stage, mbar, a_tiles, b_tiles with ring_buffer.consumer() as consumer: with consumer.get_tiles() as tiles: mma_tiles(tiles) # Access tiles.a_tiles[stage * k_group + j] ``` ## `comptime` values ### `MbarPtr` `comptime MbarPtr = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` ## Structs * [​`Consumer`](./Consumer): Consumer view with get\_tiles() API. * [​`ConsumerTiles`](./ConsumerTiles): Context manager for consumer access with stage, barrier, and tile arrays. * [​`OutputConsumerContext`](./OutputConsumerContext): Context manager for epilogue consumer access to OutputRingBuffer. * [​`OutputProducerContext`](./OutputProducerContext): Context manager for MMA producer access to OutputRingBuffer. * [​`OutputRingBuffer`](./OutputRingBuffer): Ring buffer for MMA→Epilogue output pipeline. * [​`OutputStage`](./OutputStage): Stage info for output pipeline. * [​`Producer`](./Producer): Producer view with get\_tiles() API. * [​`ProducerTiles`](./ProducerTiles): Context manager for producer access with stage, barrier, and tile arrays. * [​`RingBuffer`](./RingBuffer): Ring buffer with tile storage for SM100 producer-consumer sync.
--- ## TileLoaderTMA
`@register_passable(trivial)` `struct TileLoaderTMA[a_tma_origin: ImmutOrigin, b_tma_origin: ImmutOrigin, a_type: DType, b_type: DType, a_layout: Layout, b_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, //, a_smem_layout: Layout, b_smem_layout: Layout, BM: Int, BN: Int, BK: Int, MMA_N: Int, cta_group: Int, k_group_size: Int, num_pipeline_stages: Int, num_group_stages: Int]` TMA-based tile loader for SM100. Encapsulates the complete tile loading logic including: * K-group batching (multiple tiles per barrier) * CTA group coordination (1-SM or 2-SM cooperative loading) * Peer CTA slicing for 2-SM MMA * expect\_bytes management Template Parameters: a\_tma\_origin: Origin type for A TMA pointer. b\_tma\_origin: Origin type for B TMA pointer. a\_type: Data type for A matrix. b\_type: Data type for B matrix. a\_layout: Global memory layout for A. b\_layout: Global memory layout for B. a\_desc\_layout: TMA descriptor layout for A. b\_desc\_layout: TMA descriptor layout for B. a\_smem\_layout: Shared memory tile layout for A. b\_smem\_layout: Shared memory tile layout for B. BM: Block tile M dimension. BN: Block tile N dimension. BK: Block tile K dimension. MMA\_N: MMA N dimension for B coordinate calculation. cta\_group: Number of CTAs cooperating, 1 or 2. k\_group\_size: Number of K tiles per barrier sync. num\_pipeline\_stages: Total pipeline stages. num\_group\_stages: Pipeline stages / k\_group\_size. 
## Fields * ​a\_tma\_op (`TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOpPtr`): * ​b\_tma\_op (`TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOpPtr`): * ​a\_multicast\_mask (`UInt16`): * ​b\_multicast\_mask (`UInt16`): * ​peer\_rank\_n (`UInt`): * ​peer\_rank\_m (`UInt`): * ​peer\_m\_rank (`UInt`): * ​work\_m\_coord (`UInt`): * ​work\_n\_coord (`UInt`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_expected_bytes` `comptime a_expected_bytes = (a_smem_layout.size() * size_of[a_type]())` ### `a_tma_load_size` `comptime a_tma_load_size = a_desc_layout.size()` ### `a_tma_rows` `comptime a_tma_rows = a_desc_layout.shape[0].value()` ### `ATmaOp` `comptime ATmaOp = TMATensorTile[a_type, a_layout, a_desc_layout]` ### `ATmaOpPtr` `comptime ATmaOpPtr = Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOp, a_tma_origin]` ### `b_expected_bytes` `comptime b_expected_bytes = (b_smem_layout.size() * size_of[b_type]())` ### `b_tma_load_size` `comptime b_tma_load_size = b_desc_layout.size()` ### `b_tma_rows` `comptime b_tma_rows = b_desc_layout.shape[0].value()` ### `BTmaOp` `comptime BTmaOp = TMATensorTile[b_type, b_layout, b_desc_layout]` ### `BTmaOpPtr` `comptime BTmaOpPtr = Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, 
BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOp, b_tma_origin]` ### `expected_bytes` `comptime expected_bytes = ((cta_group * (TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].a_expected_bytes + TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].b_expected_bytes)) * k_group_size)` ## Methods ### `__init__` `__init__(a_tma_op: Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].ATmaOp, a_tma_origin], b_tma_op: Pointer[TileLoaderTMA[a_smem_layout, b_smem_layout, BM, BN, BK, MMA_N, cta_group, k_group_size, num_pipeline_stages, num_group_stages].BTmaOp, b_tma_origin], a_multicast_mask: UInt16, b_multicast_mask: UInt16, peer_cta_coord: Tuple[UInt, UInt, UInt]) -> Self` Initialize the TMA tile loader. **Args:** * ​a\_tma\_op ([`Pointer`](/mojo/stdlib/memory/pointer/Pointer)): Pointer to A matrix TMA descriptor. * ​b\_tma\_op ([`Pointer`](/mojo/stdlib/memory/pointer/Pointer)): Pointer to B matrix TMA descriptor. * ​a\_multicast\_mask ([`UInt16`](/mojo/stdlib/builtin/simd/#uint16)): Multicast mask for A tiles. * ​b\_multicast\_mask ([`UInt16`](/mojo/stdlib/builtin/simd/#uint16)): Multicast mask for B tiles. * ​peer\_cta\_coord ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Peer CTA coordinates (rank\_n, rank\_m, peer\_m\_rank). ### `set_work_tile` `set_work_tile(mut self, m_coord: UInt, n_coord: UInt)` Set the current output tile coordinates. **Args:** * ​m\_coord ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): M coordinate of the output tile. * ​n\_coord ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): N coordinate of the output tile. 
### `load_tiles` `load_tiles[tiles_origin: MutOrigin, //](self, tiles: ProducerTiles[tiles_origin, a_type, b_type, a_smem_layout, b_smem_layout, num_pipeline_stages, num_group_stages, k_group_size], iter_idx: UInt32, elect_one_cta: Bool)` Load k\_group\_size A and B tiles using TMA. **Args:** * ​tiles ([`ProducerTiles`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/ring_buffer/ProducerTiles)): ProducerTiles context with stage, barrier, and tile arrays. * ​iter\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): K iteration index (base index, not multiplied). * ​elect\_one\_cta ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): True if this CTA should call expect\_bytes.
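The `expected_bytes` member drives the barrier's expect-bytes accounting when `elect_one_cta` is true in `load_tiles`. A worked instance of the formula with illustrative sizes (the concrete numbers are assumptions chosen for the arithmetic, not defaults):

```
# a_expected_bytes = a_smem_layout.size() * size_of[a_type]()
#                  = (128 * 64) * 2           = 16384 bytes   # e.g. BM=128, BK=64, bfloat16
# b_expected_bytes = b_smem_layout.size() * size_of[b_type]()
#                  = (128 * 64) * 2           = 16384 bytes   # e.g. BN=128, BK=64, bfloat16
# expected_bytes   = cta_group * (a_expected_bytes + b_expected_bytes) * k_group_size
#                  = 2 * (16384 + 16384) * 2  = 131072 bytes  # cta_group=2, k_group_size=2
```

In other words, a single barrier arrival accounts for every A and B tile loaded for one whole k-group across the cooperating CTA group.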
--- ## tile_loader
TileLoader for SM100 matrix multiplication. Provides tile loading abstractions for efficient global-to-shared memory transfers using TMA, with support for: * K-group batching (multiple tiles per barrier synchronization) * CTA group coordination (1-SM or 2-SM cooperative loading) * Multicast for cluster distribution Usage: ``` var loader = TileLoaderTMA[...](a_tma_op, b_tma_op, a_multicast_mask, b_multicast_mask, peer_cta_coord) loader.set_work_tile(m_coord, n_coord) with producer.get_tiles() as tiles: loader.load_tiles(tiles, k_iter, elect_one_cta) ``` ## Structs * [​`TileLoaderTMA`](./TileLoaderTMA): TMA-based tile loader for SM100.
--- ## AdvanceAfterWorkContext
`@register_passable(trivial)` `struct AdvanceAfterWorkContext[work_origin: MutOrigin, state_origin: MutOrigin, num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). * **enter**: Returns current work\_info for use in the block * **exit**: Fetches next work, assigns to work\_info, steps state ## Fields * ​scheduler (`AdvanceAfterWorkContext[work_origin, state_origin, num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType`): * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​consumer\_state\_ptr (`Pointer[PipelineState[num_stages], state_origin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size], work_info_ptr: Pointer[WorkInfo, work_origin], consumer_state_ptr: Pointer[PipelineState[num_stages], state_origin]) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
--- ## PrefetchBeforeWorkContext
`@register_passable(trivial)` `struct PrefetchBeforeWorkContext[work_origin: MutOrigin]` Context for MMA warp that prefetches BEFORE work (software pipelining). * Construction: Fetches next work and steps state immediately * **enter**: Returns current work\_info for use in the block * **exit**: Assigns prefetched work to work\_info ## Fields * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​next\_work (`WorkInfo`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(work_info_ptr: Pointer[WorkInfo, work_origin], next_work: WorkInfo) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
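The two context types differ only in when the next work item is fetched. A pseudocode-level sketch contrasting them (`do_work` and `do_mma` are placeholder names):

```
# AdvanceAfterWorkContext: fetch on exit (Load/Scheduler/Epilogue warps).
with scheduler.advance_after_work(work_info, state) as current:
    do_work(current)
# __exit__ runs here: fetches next work into work_info, steps state.

# PrefetchBeforeWorkContext: fetch before entry (MMA warp).
# Construction already issued the fetch and stepped state, so the fetch
# overlaps with the MMA work inside the block body.
with scheduler.prefetch_before_work(work_info, state) as current:
    do_mma(current)
# __exit__ runs here: assigns the prefetched work to work_info.
```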
--- ## SchedulerWorkIterator
`@register_passable(trivial)` `struct SchedulerWorkIterator[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int]` Work iterator for the Scheduler warp; owns work\_info and both pipeline states. The Scheduler warp uniquely needs to: 1. Consume work responses (like other warps) via next() 2. Signal throttle and produce new work requests via signal\_and\_advance() 3. Drain pending requests at exit via drain() Usage: ``` var sched_iter = scheduler.scheduler_iterator() while sched_iter.has_work(): with sched_iter.next(): sched_iter.signal_and_advance() sched_iter.drain() ``` ## Fields * ​scheduler (`SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​producer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ### `ThrottlePipeline` `comptime ThrottlePipeline = SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, cluster_shape, rasterize_order, 
block_swizzle_size], work_info: WorkInfo) -> Self` Create scheduler iterator. Throttle pipeline from scheduler. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref [state_origin] self) -> AdvanceAfterWorkContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Get next work item. **Returns:** `AdvanceAfterWorkContext` ### `signal_and_advance` `signal_and_advance(mut self)` Signal CLC throttle consumer and advance to next work request. Combines two operations that always happen together in Scheduler warp: 1. Signal throttle consumer (tells Load warp we've consumed a response) 2. Issue next CLC work request (producer side) ### `drain` `drain(mut self)` Drain all pending CLC requests before kernel exit. Must be called after the work loop completes to ensure all CLC pipeline stages are properly synchronized before exit.
--- ## TileScheduler (3)
`@register_passable(trivial)` `struct TileScheduler[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8]` ## Fields * ​cluster\_dim (`StaticTuple[Int32, 3]`): * ​log\_cluster\_dim\_m (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_n (`FastDiv[DType.uint32]`): * ​log\_cluster\_dim\_k (`FastDiv[DType.uint32]`): * ​clc\_response (`LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]`): * ​full\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​empty\_mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​throttle\_pipeline (`TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `cluster_size` `comptime cluster_size = ((cluster_shape.__getitem__[3, DType.uint32, Int](0) * cluster_shape.__getitem__[3, DType.uint32, Int](1)) * cluster_shape.__getitem__[3, DType.uint32, Int](2))` ### `log_cluster_k` `comptime log_cluster_k = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](2))` ### `log_cluster_m` `comptime log_cluster_m = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](0))` ### `log_cluster_n` `comptime log_cluster_n = FastDiv[DType.uint32](cluster_shape.__getitem__[3, DType.uint32, Int](1))` ### `ThrottlePipeline` `comptime 
ThrottlePipeline = ProducerConsumerPipeline[num_stages]` ## Methods ### `__init__` `__init__(cluster_dim: StaticTuple[Int32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], throttle_storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `init_throttle_barriers` `static init_throttle_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize throttle pipeline barriers. Called once by elect\_one thread. **Args:** * ​storage\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Pointer to shared memory barrier storage. * ​producer\_arv\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Expected arrival count for producer barriers. * ​consumer\_arv\_count ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Expected arrival count for consumer barriers. 
### `work_info_from_clc_response` `static work_info_from_clc_response(result: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `work_info_from_cluster` `static work_info_from_cluster(work_info: WorkInfo, cluster_dim: StaticTuple[Int32, 3], log_cluster_dim_m: FastDiv[DType.uint32], log_cluster_dim_n: FastDiv[DType.uint32]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `initial_work_info` `initial_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `fetch_next_work` `fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_after_work` `advance_after_work[work_origin: MutOrigin, state_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, ref [state_origin] consumer_state: PipelineState[num_stages]) -> AdvanceAfterWorkContext[work_origin, state_origin, num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). Usage: with scheduler.advance\_after\_work(work\_info, state) as current: do\_work(current) syncwarp() \# After: work\_info updated, state stepped **Returns:** `AdvanceAfterWorkContext` ### `prefetch_before_work` `prefetch_before_work[work_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> PrefetchBeforeWorkContext[work_origin]` Context for MMA warp that prefetches BEFORE work (software pipelining). Fetches next work and steps state IMMEDIATELY (before the with block). 
Usage: with scheduler.prefetch\_before\_work(work\_info, state) as current: do\_mma(current) # Uses current, not prefetched \# After: work\_info updated to prefetched value **Returns:** `PrefetchBeforeWorkContext` ### `work_iterator` `work_iterator(self) -> WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Create a per-warp work iterator with internally managed state. Each warp should create its own work iterator. The iterator owns work\_info, pipeline state, and throttle internally. Usage: var work\_iter = scheduler.work\_iterator() while work\_iter.has\_work(): with work\_iter.next() as current: work\_iter.throttle\_signal(ctx.is\_first\_cta\_in\_cluster) do\_work(current) **Returns:** `WorkIterator` ### `scheduler_iterator` `scheduler_iterator(self) -> SchedulerWorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Create iterator for Scheduler warp (owns work\_info and both pipeline states). The Scheduler warp uniquely needs to both consume work responses and produce new work requests. This iterator owns everything internally. Usage: var sched\_iter = scheduler.scheduler\_iterator() while sched\_iter.has\_work(): with sched\_iter.next(): sched\_iter.signal\_and\_advance() sched\_iter.drain() **Returns:** `SchedulerWorkIterator` ### `advance_to_next_work` `advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]` **Returns:** [`PipelineState`](/mojo/kernels/layout/tma_async/PipelineState)
--- ## WorkInfo (3)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
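`WorkInfo` acts as the loop sentinel throughout this module: iteration continues while `is_valid()` holds. A minimal sketch (`process_tile` is a placeholder; real warps use the iterator wrappers in this module):

```
var work = scheduler.initial_work_info()
while work.is_valid():
    process_tile(work.m, work.n, work.k_start)
    work = scheduler.fetch_next_work(work, consumer_state)
```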
--- ## WorkIterator
`@register_passable(trivial)` `struct WorkIterator[num_stages: Int, cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int]` Per-warp work iterator that owns work\_info and pipeline state. Each warp creates its own WorkIterator, which internally manages both the current work item and the CLC pipeline consumer state. The throttle pipeline is obtained from the scheduler. Usage: ``` var work_iter = scheduler.work_iterator() while work_iter.has_work(): with work_iter.next() as current: work_iter.throttle_signal(ctx.is_first_cta_in_cluster) do_work(current) ``` ## Fields * ​scheduler (`WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ### `ThrottlePipeline` `comptime ThrottlePipeline = WorkIterator[num_stages, cluster_shape, rasterize_order, block_swizzle_size].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size], work_info: WorkInfo) -> Self` Create work iterator with initial work\_info. Throttle from scheduler. 
### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref [state_origin] self) -> AdvanceAfterWorkContext[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, cluster_shape, rasterize_order, block_swizzle_size]` Get next work item (advance AFTER work pattern). **Returns:** `AdvanceAfterWorkContext` ### `next_prefetch` `next_prefetch[state_origin: MutOrigin, //](ref [state_origin] self) -> PrefetchBeforeWorkContext[origin_of(state_origin._mlir_origin.work_info)]` Get next work item with prefetch (advance BEFORE work pattern). **Returns:** `PrefetchBeforeWorkContext` ### `throttle_signal` `throttle_signal(mut self, is_first_cta_in_cluster: Bool)` Signal CLC throttle if this is the first CTA in cluster. The Load warp acts as producer for CLC throttle, signaling that it has started processing a new work item. This prevents the scheduler from getting too far ahead. **Args:** * ​is\_first\_cta\_in\_cluster ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only first CTA signals to avoid duplicates.
--- ## tile_scheduler
## Structs * [​`AdvanceAfterWorkContext`](./AdvanceAfterWorkContext): Context for warps that do work THEN advance (Load/Scheduler/Epilogue). * [​`PrefetchBeforeWorkContext`](./PrefetchBeforeWorkContext): Context for MMA warp that prefetches BEFORE work (software pipelining). * [​`SchedulerWorkIterator`](./SchedulerWorkIterator): Work iterator for Scheduler warp - owns work\_info and both pipeline states. * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo): * [​`WorkIterator`](./WorkIterator): Per-warp work iterator that owns work\_info and pipeline state.
--- ## AdvanceAfterWorkContextSplitK
`@register_passable(trivial)` `struct AdvanceAfterWorkContextSplitK[work_origin: MutOrigin, state_origin: MutOrigin, num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). ## Fields * ​scheduler (`AdvanceAfterWorkContextSplitK[work_origin, state_origin, num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType`): * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​consumer\_state\_ptr (`Pointer[PipelineState[num_stages], state_origin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], work_info_ptr: Pointer[WorkInfo, work_origin], consumer_state_ptr: Pointer[PipelineState[num_stages], state_origin]) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
--- ## PrefetchBeforeWorkContextSplitK
`@register_passable(trivial)` `struct PrefetchBeforeWorkContextSplitK[work_origin: MutOrigin]` Context for MMA warp that prefetches BEFORE work (software pipelining). ## Fields * ​work\_info\_ptr (`Pointer[WorkInfo, work_origin]`): * ​next\_work (`WorkInfo`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(work_info_ptr: Pointer[WorkInfo, work_origin], next_work: WorkInfo) -> Self` ### `__enter__` `__enter__(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `__exit__` `__exit__(mut self)`
--- ## SchedulerWorkIteratorSplitK
`@register_passable(trivial)` `struct SchedulerWorkIteratorSplitK[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]` Work iterator for Scheduler warp (split-K) - owns work\_info and both states. Throttle pipeline is obtained from the scheduler. ## Fields * ​scheduler (`SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​producer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` ### `ThrottlePipeline` `comptime ThrottlePipeline = SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], work_info: WorkInfo) -> Self` Create scheduler iterator. 
The throttle pipeline is obtained from the scheduler. ### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref [state_origin] self) -> AdvanceAfterWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Get the next work item. **Returns:** `AdvanceAfterWorkContextSplitK` ### `signal_and_advance` `signal_and_advance(mut self)` Signal the CLC throttle consumer and advance to the next work request. ### `drain` `drain(mut self)` Drain all pending CLC requests before kernel exit.
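The method set above implies a simple loop shape for the Scheduler warp. Below is a plain-Python mock of that loop; the class and its behavior are illustrative stand-ins for the Mojo API, which runs on-GPU:

```python
# Hypothetical Python mock of the Scheduler-warp loop implied by
# SchedulerWorkIteratorSplitK (has_work / signal_and_advance / drain).
# Names and behavior are illustrative assumptions, not the Mojo API.
class MockSchedulerIterator:
    def __init__(self, work_items):
        self._queue = list(work_items)
        self.drained = False

    def has_work(self):
        return bool(self._queue)

    def signal_and_advance(self):
        # Stand-in for signaling the CLC throttle consumer and advancing.
        self._queue.pop(0)

    def drain(self):
        # Stand-in for draining pending CLC requests before kernel exit.
        self.drained = True

def scheduler_warp_loop(it, do_work):
    processed = 0
    while it.has_work():
        do_work()              # stands in for `with it.next() as current: ...`
        it.signal_and_advance()
        processed += 1
    it.drain()                 # always drain before kernel exit
    return processed
```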
--- ## TileScheduler (4)
`@register_passable(trivial)` `struct TileScheduler[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32] = Index[dtype=DType.uint32](1, 1, 1), rasterize_order: RasterOrder = RasterOrder.AlongM, block_swizzle_size: Int = 8, num_split_k: Int = 1]` ## Fields * ​locks\_ptr (`LegacyUnsafePointer[Int32]`): * ​scheduler (`TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].UnderlyingScheduler`): * ​total\_k\_tiles (`UInt32`): * ​k\_tiles\_per\_split (`UInt32`): * ​throttle\_pipeline (`TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BK` `comptime BK = reduction_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = reduction_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `MMA_N` `comptime MMA_N = reduction_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `ROW_SIZE` `comptime ROW_SIZE = reduction_tile_shape.__getitem__[3, DType.int64, Int](1) if (reduction_tile_shape.__getitem__[3, DType.int64, Int](0) == 128) else (TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].MMA_N // 2)` ### `ThrottlePipeline` `comptime ThrottlePipeline = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, 
num_split_k].UnderlyingScheduler.ThrottlePipeline` ### `UnderlyingScheduler` `comptime UnderlyingScheduler = TileScheduler[num_stages, cluster_shape, rasterize_order, block_swizzle_size]` ## Methods ### `__init__` `__init__(cluster_dim: StaticTuple[Int32, 3], mnk: StaticTuple[UInt32, 3], clc_response_ptr: LegacyUnsafePointer[UInt128, address_space=AddressSpace.SHARED], full_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], empty_mbar_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], throttle_storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], locks_ptr: LegacyUnsafePointer[UInt8]) -> Self` ### `init_throttle_barriers` `static init_throttle_barriers(storage_ptr: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], producer_arv_count: Int32, consumer_arv_count: Int32)` Initialize throttle pipeline barriers. Called once by elect\_one thread. ### `convert_to_splitk_work_info` `convert_to_splitk_work_info(self, work_info: WorkInfo) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `initial_work_info` `initial_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_to_next_work` `advance_to_next_work(self, mut clc_state: PipelineState[num_stages]) -> PipelineState[num_stages]` **Returns:** [`PipelineState`](/mojo/kernels/layout/tma_async/PipelineState) ### `fetch_next_work` `fetch_next_work(self, work_info: WorkInfo, consumer_state: PipelineState[num_stages]) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `advance_after_work` `advance_after_work[work_origin: MutOrigin, state_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, ref [state_origin] consumer_state: PipelineState[num_stages]) -> AdvanceAfterWorkContextSplitK[work_origin, state_origin, num_stages, 
reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Context for warps that do work THEN advance (Load/Scheduler/Epilogue). Usage: `with scheduler.advance_after_work(work_info, state) as current: do_work(current)`, followed by `syncwarp()`. After the block exits, `work_info` is updated and the pipeline state is stepped. **Returns:** `AdvanceAfterWorkContextSplitK` ### `prefetch_before_work` `prefetch_before_work[work_origin: MutOrigin, //](self, ref [work_origin] work_info: WorkInfo, mut consumer_state: PipelineState[num_stages]) -> PrefetchBeforeWorkContextSplitK[work_origin]` Context for the MMA warp that prefetches BEFORE work (software pipelining). Fetches the next work and steps the state IMMEDIATELY (before the `with` block). Usage: `with scheduler.prefetch_before_work(work_info, state) as current: do_mma(current)`; the body uses `current`, not the prefetched item. After the block exits, `work_info` is updated to the prefetched value. **Returns:** `PrefetchBeforeWorkContextSplitK` ### `work_iterator` `work_iterator(self) -> WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Create a per-warp work iterator that owns work\_info internally. The throttle pipeline is obtained from the scheduler. **Returns:** `WorkIteratorSplitK` ### `scheduler_iterator` `scheduler_iterator(self) -> SchedulerWorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Create an iterator for the Scheduler warp (owns work\_info and both states). The throttle pipeline is obtained from the scheduler. 
**Returns:** `SchedulerWorkIteratorSplitK` ### `is_last_split` `is_last_split(self, work_tile_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `output_tile_index` `output_tile_index(self, work_info: WorkInfo) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `store_to_workspace` `store_to_workspace[accum_type: DType, workspace_layout: Layout, /, *, do_reduction: Bool = False, write_back: Bool = False](self, tmem_addr: UInt32, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], epilogue_thread_idx: UInt, reduction_tile_idx: UInt32)` ### `reduction` `reduction[accum_type: DType, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, origin], tmem_addr: UInt32, epilogue_thread_idx: UInt, work_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `wait_eq` `static wait_eq(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)` ### `wait_lt` `static wait_lt(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)` ### `arrive_set` `static arrive_set(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)`
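To make the two documented scheduling patterns concrete, here is a small Python model; the context managers below are illustrative stand-ins for the Mojo `AdvanceAfterWorkContextSplitK` and `PrefetchBeforeWorkContextSplitK` contexts, not their real implementations:

```python
# Python model of the two patterns: advance AFTER work vs. prefetch BEFORE work.
from contextlib import contextmanager

@contextmanager
def advance_after_work(state):
    # Load/Scheduler/Epilogue warps: the body runs on the CURRENT item;
    # fetching the next item and stepping the state happen on exit.
    yield state["work"]
    state["work"] += 1

@contextmanager
def prefetch_before_work(state):
    # MMA warp: fetch the next work and step the state IMMEDIATELY, so the
    # next request is in flight while the body still uses the current item.
    current = state["work"]
    state["work"] += 1
    yield current

state = {"work": 0}
with advance_after_work(state) as cur:
    seen_after = (cur, state["work"])    # state not yet advanced inside body
with prefetch_before_work(state) as cur:
    seen_before = (cur, state["work"])   # state already advanced inside body
```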
--- ## WorkInfo (4)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​num\_k\_tiles (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `INVALID_WORK_INFO` `comptime INVALID_WORK_INFO = WorkInfo(0, 0, 0, 0, False)` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_final_split` `is_final_split(self, k_tiles_per_output_tile: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
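The fields above can be mirrored in plain Python to show the split-K bookkeeping. Note that the `is_final_split` predicate shown here (the split reaching the end of the output tile's K range) is an assumption about the documented method, not the verified Mojo implementation:

```python
# Plain-Python mirror of WorkInfo (illustrative only).
from dataclasses import dataclass

@dataclass
class WorkInfoMock:
    m: int
    n: int
    k_start: int        # first K tile this split covers
    num_k_tiles: int    # number of K tiles this split covers
    is_valid_tile: bool

    def is_valid(self):
        return self.is_valid_tile

    def is_final_split(self, k_tiles_per_output_tile):
        # Assumed semantics: final split iff it reaches the end of K.
        return self.k_start + self.num_k_tiles >= k_tiles_per_output_tile

INVALID_WORK_INFO = WorkInfoMock(0, 0, 0, 0, False)
```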
--- ## WorkIteratorSplitK
`@register_passable(trivial)` `struct WorkIteratorSplitK[num_stages: Int, reduction_tile_shape: IndexList[3], cluster_shape: IndexList[3, element_type=DType.uint32], rasterize_order: RasterOrder, block_swizzle_size: Int, num_split_k: Int]` Per-warp work iterator for split-K that owns work\_info and pipeline state. Throttle pipeline is obtained from the scheduler. ## Fields * ​scheduler (`WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType`): * ​work\_info (`WorkInfo`): * ​consumer\_state (`PipelineState[num_stages]`): * ​throttle\_pipeline (`WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].ThrottlePipeline`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `SchedulerType` `comptime SchedulerType = TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` ### `ThrottlePipeline` `comptime ThrottlePipeline = WorkIteratorSplitK[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k].SchedulerType.ThrottlePipeline` ## Methods ### `__init__` `__init__(scheduler: TileScheduler[num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k], work_info: WorkInfo) -> Self` Create a work iterator. The throttle pipeline is obtained from the scheduler. 
### `has_work` `has_work(self) -> Bool` Check if there is more work to process. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `next` `next[state_origin: MutOrigin, //](ref [state_origin] self) -> AdvanceAfterWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info), origin_of(state_origin._mlir_origin.consumer_state), num_stages, reduction_tile_shape, cluster_shape, rasterize_order, block_swizzle_size, num_split_k]` Get the next work item (advance AFTER work pattern). **Returns:** `AdvanceAfterWorkContextSplitK` ### `next_prefetch` `next_prefetch[state_origin: MutOrigin, //](ref [state_origin] self) -> PrefetchBeforeWorkContextSplitK[origin_of(state_origin._mlir_origin.work_info)]` Get the next work item with prefetch (advance BEFORE work pattern). **Returns:** `PrefetchBeforeWorkContextSplitK` ### `throttle_signal` `throttle_signal(mut self, is_first_cta_in_cluster: Bool)` Signal the CLC throttle if this is the first CTA in the cluster. **Args:** * ​is\_first\_cta\_in\_cluster ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only the first CTA signals, to avoid duplicate signals.
--- ## get_num_tiles (Tile_scheduler_splitk)
`get_num_tiles(problem_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
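The signature maps an (M, N, K) problem shape plus block-tile and cluster shapes to a 2-D tile grid. The sketch below shows one plausible computation (ceil division rounded up to whole clusters); it is an assumption for illustration, not the verified Mojo implementation:

```python
# Hypothetical sketch of a get_num_tiles-style computation.
def ceildiv(a, b):
    return -(-a // b)

def get_num_tiles(problem_shape, block_tile_shape, cluster_shape):
    m, n, _k = problem_shape
    bm, bn, _bk = block_tile_shape
    cm, cn = cluster_shape
    # Round the per-dimension tile counts up to whole clusters.
    m_tiles = ceildiv(ceildiv(m, bm), cm) * cm
    n_tiles = ceildiv(ceildiv(n, bn), cn) * cn
    return (m_tiles, n_tiles)
```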
--- ## get_required_locks_buffer_size_bytes (Tile_scheduler_splitk)
`get_required_locks_buffer_size_bytes[accum_type: DType](problem_shape: IndexList[3], block_tile_shape: IndexList[3], cluster_shape: IndexList[2]) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## tile_scheduler_splitk (Tile_scheduler_splitk)
## Structs * [​`AdvanceAfterWorkContextSplitK`](./AdvanceAfterWorkContextSplitK): Context for warps that do work THEN advance (Load/Scheduler/Epilogue). * [​`PrefetchBeforeWorkContextSplitK`](./PrefetchBeforeWorkContextSplitK): Context for MMA warp that prefetches BEFORE work (software pipelining). * [​`SchedulerWorkIteratorSplitK`](./SchedulerWorkIteratorSplitK): Work iterator for Scheduler warp (split-K) - owns work\_info and both states. Throttle pipeline is obtained from the scheduler. * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo): * [​`WorkIteratorSplitK`](./WorkIteratorSplitK): Per-warp work iterator for split-K that owns work\_info and pipeline state. Throttle pipeline is obtained from the scheduler. ## Functions * [​`get_num_tiles`](./get_num_tiles): * [​`get_required_locks_buffer_size_bytes`](./get_required_locks_buffer_size_bytes):
--- ## AccumBarrier
`@register_passable(trivial)` `struct AccumBarrier[cta_group: Int]` Helper for accumulator pipeline barrier operations. Handles the different arrival patterns for single-CTA vs 2-CTA groups. Template Parameters: cta\_group: Number of CTAs cooperating (1 or 2). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `arrive` `static arrive(pipeline: ProducerConsumerPipeline[num_stages], stage: UInt32)` Signal accumulator arrival. **Args:** * ​pipeline ([`ProducerConsumerPipeline`](/mojo/kernels/linalg/matmul/gpu/sm100/pipeline/ProducerConsumerPipeline)): The MMA output pipeline. * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current pipeline stage.
--- ## AccumTile
`@register_passable(trivial)` `struct AccumTile[dtype: DType, size: Int]` Accumulator tile holding upper and lower fragment data. SM100 accumulators in TMEM are stored as two halves (upper 16 rows, lower 16 rows). This struct represents the complete tile being written. This is the SM100 equivalent of SM90's RegTileType - the data being written by the tile writer. Template Parameters: dtype: Data type of the fragments (typically epilogue\_dtype). size: Number of elements per fragment. ## Fields * ​upper (`SIMD[dtype, size]`): * ​lower (`SIMD[dtype, size]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(upper: SIMD[dtype, size], lower: SIMD[dtype, size]) -> Self` Create an accumulator tile from upper and lower fragments.
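The upper/lower split can be pictured with plain Python lists: a 32-row tile stored as two 16-row halves. Shapes and data here are illustrative only, not the TMEM layout:

```python
# A 32-row x 8-column tile of stand-in data, split the way the docstring
# describes: upper 16 rows and lower 16 rows.
rows = [[r * 8 + c for c in range(8)] for r in range(32)]
upper, lower = rows[:16], rows[16:]
```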
--- ## EpilogueApplier
`@register_passable(trivial)` `struct EpilogueApplier[MMA_M: Int, stageN: Int, num_stages: Int, repeats: Int, cta_group: Int, transpose_c: Bool]` Apply element-wise epilogue operations on register fragments. Computes global coordinates for each element and applies a lambda function. Handles different MMA layouts (A/B/D/F) and transpose modes. Template Parameters: MMA\_M: MMA M dimension. stageN: Stage width in elements. num\_stages: Number of output stages. repeats: Number of repetitions per load. cta\_group: Number of CTAs cooperating (1 or 2). transpose\_c: Whether output is transposed. ## Fields * ​coords (`EpilogueApplier[MMA_M, stageN, num_stages, repeats, cta_group, transpose_c].Coords`): * ​warp\_id (`UInt32`): * ​lane\_id (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Coords` `comptime Coords = FragmentCoords[stageN, repeats]` ## Methods ### `__init__` `__init__(warp_id: UInt32, lane_id: UInt32) -> Self` Initialize the epilogue applier. **Args:** * ​warp\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Warp ID within the CTA. * ​lane\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Lane ID within the warp. ### `compute_staged_coords` `compute_staged_coords(self, stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[UInt32, UInt32]` Compute staged row and column coordinates. **Args:** * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current stage index. * ​c\_row ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base row coordinate. 
* ​c\_col ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base column coordinate. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (staged\_row, staged\_col). ### `apply_to_fragment` `apply_to_fragment[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type](self, mut frag: SIMD[epilogue_dtype, frag_size], staged_row: UInt32, staged_col: UInt32, is_upper: Bool)` Apply epilogue lambda to a fragment. **Args:** * ​frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Fragment to apply epilogue to (modified in place). * ​staged\_row ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Staged row coordinate. * ​staged\_col ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Staged column coordinate. * ​is\_upper ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether this is the upper or lower fragment. ### `apply_to_both_fragments` `apply_to_both_fragments[epilogue_dtype: DType, frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type, is_lower_frag_required: Bool](self, mut upper_frag: SIMD[epilogue_dtype, frag_size], mut lower_frag: SIMD[epilogue_dtype, frag_size], stage: UInt32, c_row: UInt32, c_col: UInt32) -> Tuple[SIMD[epilogue_dtype, frag_size], SIMD[epilogue_dtype, frag_size]]` Apply epilogue lambda to both upper and lower fragments. This is the main entry point for register-based epilogue, replacing the standalone register\_epilogue function. **Args:** * ​upper\_frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Upper fragment to apply epilogue to. * ​lower\_frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Lower fragment to apply epilogue to. * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current stage index. * ​c\_row ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base row coordinate. * ​c\_col ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base column coordinate. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (modified upper\_frag, modified lower\_frag).
--- ## EpilogueConfig
`@register_passable(trivial)` `struct EpilogueConfig[MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, transpose_c: Bool]` Configuration for epilogue stage computations. Computes the number of stages and other parameters needed for the output epilogue based on MMA and CTA configuration. Template Parameters: MMA\_M: MMA M dimension. MMA\_N: MMA N dimension. stageN: Stage width in elements. cta\_group: Number of CTAs cooperating (1 or 2). transpose\_c: Whether output is transposed. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `bits` `comptime bits = 256` ### `cg1_num_stages` `comptime cg1_num_stages = (MMA_N // stageN)` ### `cg2_num_stages` `comptime cg2_num_stages = (MMA_N // stageN) if (MMA_M == 256) else ((MMA_N // stageN) // 2)` ### `data_paths` `comptime data_paths = 16` ### `fragment_size` `comptime fragment_size = 4` ### `is_lower_frag_required` `comptime is_lower_frag_required = (MMA_M == 64) if (cta_group == 1) else (cta_group == 1).__bool__().__invert__()` ### `num_stages` `comptime num_stages = (MMA_N // stageN) if (MMA_M == 256) else ((MMA_N // stageN) // 2) if (cta_group == 2) else EpilogueConfig[MMA_M, MMA_N, stageN, cta_group, transpose_c].cg1_num_stages`
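The `comptime` stage formulas above can be restated in plain Python. This is a direct transcription of the documented expressions, with the MMA shape and `cta_group` as inputs:

```python
# Python restatement of EpilogueConfig's documented comptime members.
def epilogue_num_stages(MMA_M, MMA_N, stageN, cta_group):
    cg1_num_stages = MMA_N // stageN
    cg2_num_stages = cg1_num_stages if MMA_M == 256 else cg1_num_stages // 2
    return cg2_num_stages if cta_group == 2 else cg1_num_stages

def is_lower_frag_required(MMA_M, cta_group):
    # cta_group == 1 requires the lower fragment only when MMA_M == 64;
    # cta_group == 2 always requires it.
    return MMA_M == 64 if cta_group == 1 else True
```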
--- ## FragmentCoords
`@register_passable(trivial)` `struct FragmentCoords[stageN: Int, repeats: Int]` Compute coordinates for fragment elements in tensor memory layout. Based on tcgen05 matrix fragment layout (16x256b): Template Parameters: stageN: Stage width in elements. repeats: Number of repetitions for wider loads. ## Fields * ​top\_upper (`StaticTuple[UInt32, 2]`): * ​bottom\_upper (`StaticTuple[UInt32, 2]`): * ​top\_lower (`StaticTuple[UInt32, 2]`): * ​bottom\_lower (`StaticTuple[UInt32, 2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `load_width` `comptime load_width = 2` ### `threads_per_row` `comptime threads_per_row = ((stageN // repeats) // 2)` ## Methods ### `__init__` `__init__(lane_id: UInt32) -> Self` Initialize fragment coordinates based on lane ID. **Args:** * ​lane\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Lane ID within the warp.
--- ## OutputStageWriter
`@register_passable(trivial)` `struct OutputStageWriter[c_type: DType, c_smem_layout: Layout, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]` Orchestrate writing a single output stage. Coordinates TMEM read, optional epilogue, st.matrix to SMEM, and TMA store. Template Parameters: c\_type: Output data type. c\_smem\_layout: Shared memory tile layout. MMA\_M: MMA M dimension. MMA\_N: MMA N dimension. stageN: Stage width in elements. cta\_group: Number of CTAs cooperating. c\_swizzle: TMA swizzle mode. transpose\_c: Whether output is transposed. ## Fields * ​st\_writer (`OutputStageWriter[c_type, c_smem_layout, MMA_M, MMA_N, stageN, cta_group, c_swizzle, transpose_c].StWriter`): * ​warp\_id (`UInt32`): * ​lane\_id (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Config` `comptime Config = EpilogueConfig[MMA_M, MMA_N, stageN, cta_group, transpose_c]` ### `StWriter` `comptime StWriter = StMatrixWriter[c_type, c_smem_layout, stageN, c_swizzle, transpose_c]` ## Methods ### `__init__` `__init__(warp_id: UInt32, lane_id: UInt32) -> Self` Initialize the output stage writer. **Args:** * ​warp\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Warp ID within the CTA. * ​lane\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Lane ID within the warp. 
### `write_upper_fragment` `write_upper_fragment[frag_size: Int, epilogue_dtype: DType](self, frag: SIMD[epilogue_dtype, frag_size], dst: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_offset: UInt32 = 0)` Write the upper fragment to shared memory. **Args:** * ​frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Upper fragment (already cast to epilogue\_dtype). * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination shared memory tile. * ​warp\_offset ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Additional warp-based offset for transpose mode. ### `write_lower_fragment` `write_lower_fragment[frag_size: Int, epilogue_dtype: DType](self, frag: SIMD[epilogue_dtype, frag_size], dst: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_offset: UInt32 = 0)` Write the lower fragment to shared memory. **Args:** * ​frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Lower fragment (already cast to epilogue\_dtype). * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination shared memory tile. * ​warp\_offset ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Additional warp-based offset for transpose mode.
--- ## SMemEpilogueWriter
`@register_passable(trivial)` `struct SMemEpilogueWriter[c_type: DType, c_smem_layout: Layout, num_output_stages: Int, //, epilogue_dtype: DType, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool, num_stages: Int, simd_size: Int, stage: Int, rep_frag_size: Int, compute_lambda_fn: elementwise_compute_lambda_type]` Write accumulator tile to SMEM and apply element-wise epilogue lambda. This writer handles the SMEM-based epilogue path when register\_based\_epilogue=False. Inferred from c\_tiles: c\_type, c\_smem\_layout, num\_output\_stages. Derived from layout: stageN, stage\_contiguous\_size. ## Fields * ​warp\_id (`UInt32`): * ​c\_tiles (`SMemEpilogueWriter[epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].CTileArray`): * ​M (`UInt32`): * ​N (`UInt32`): * ​c\_row (`UInt32`): * ​c\_col (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `barrier_threads` `comptime barrier_threads = (num_output_warps * WARP_SIZE)` ### `CTileArray` `comptime CTileArray = SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128]` ### `data_paths` `comptime data_paths = 16` ### `N_dim` `comptime N_dim = 0 if transpose_c else 1` ### `stage_contiguous_size` `comptime stage_contiguous_size = c_smem_layout.shape[1].value()` ### `stageN` `comptime 
stageN = c_smem_layout.shape[SMemEpilogueWriter[epilogue_dtype, BM, BN, MMA_M, MMA_N, cta_group, num_output_warps, c_swizzle, transpose_c, is_lower_frag_required, num_stages, simd_size, stage, rep_frag_size, compute_lambda_fn].N_dim].value()` ### `swizzle` `comptime swizzle = make_swizzle[c_type, c_swizzle]()` ### `swizzle_width` `comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())` ### `Tile` `comptime Tile = AccumTile[epilogue_dtype, rep_frag_size]` ## Methods ### `__init__` `__init__(warp_id: UInt32, c_tiles: SMemTileArrayType[c_type, c_smem_layout, num_output_stages, 128], c_shape: Tuple[UInt32, UInt32], c_coord: Tuple[UInt32, UInt32]) -> Self` Initialize the SMEM epilogue writer. ### `write_tile` `write_tile(self, tile: AccumTile[epilogue_dtype, rep_frag_size])` Write accumulator tile to SMEM and apply epilogue lambda.
--- ## StMatrixConfig
`@register_passable(trivial)` `struct StMatrixConfig[c_type: DType, stageN: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]` Configuration for st.matrix store operations. Computes the various constants needed for st.matrix operations based on the output tile configuration. Template Parameters: c\_type: Output data type (e.g., bfloat16). stageN: Stage width in elements. c\_swizzle: TMA swizzle mode. transpose\_c: Whether output is transposed. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `stmtx_simd_width` `comptime stmtx_simd_width = 4 if ((stageN % 16) == 0) else 2` ### `stsmx_lane_size` `comptime stsmx_lane_size = (16 // size_of[c_type]())` ### `stsmx_row_size` `comptime stsmx_row_size = (32 // size_of[c_type]()) if ((stageN % 16) == 0) else (16 // size_of[c_type]())` ### `swizzle_width` `comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())` ## Methods ### `make_swizzle` `static make_swizzle() -> Swizzle` Create the swizzle pattern for st.matrix operations. **Returns:** [`Swizzle`](/mojo/kernels/layout/swizzle/Swizzle): Swizzle instance for the configured swizzle mode.
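The comptime constants above reduce to a few byte-width formulas. Here they are in plain Python, with `elem_bytes` standing in for `size_of[c_type]()` (e.g. 2 for bfloat16) and `swizzle_bytes` for the swizzle mode's byte width (128 for the SWIZZLE_128B default):

```python
# Python restatement of StMatrixConfig's documented comptime members.
def stmatrix_config(stageN, elem_bytes, swizzle_bytes=128):
    stmtx_simd_width = 4 if stageN % 16 == 0 else 2
    stsmx_lane_size = 16 // elem_bytes
    stsmx_row_size = (32 if stageN % 16 == 0 else 16) // elem_bytes
    swizzle_width = swizzle_bytes // elem_bytes
    return stmtx_simd_width, stsmx_lane_size, stsmx_row_size, swizzle_width
```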
--- ## StMatrixCoords
`@register_passable(trivial)` `struct StMatrixCoords[MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, transpose_c: Bool]` Compute coordinates for st.matrix operations. Encapsulates the complex coordinate calculations needed for storing accumulator fragments to shared memory. Template Parameters: MMA\_M: MMA M dimension. MMA\_N: MMA N dimension. stageN: Stage N dimension (width of each output tile). cta\_group: Number of CTAs cooperating (1 or 2). transpose\_c: Whether output is transposed. ## Fields * ​warp\_id (`UInt32`): * ​lane\_id (`UInt32`): * ​c\_row (`UInt32`): * ​c\_col (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(warp_id: UInt32, lane_id: UInt32, c_row: UInt32, c_col: UInt32) -> Self` Initialize coordinate calculator. **Args:** * ​warp\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Warp ID within the CTA. * ​lane\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Lane ID within the warp. * ​c\_row ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base row coordinate in global memory. * ​c\_col ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base column coordinate in global memory. ### `staged_row` `staged_row(self, stage: UInt32, num_stages: UInt32) -> UInt32` Compute the staged row coordinate. **Args:** * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current stage index. * ​num\_stages ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Total number of stages. 
**Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): Row coordinate for the current stage. ### `staged_col` `staged_col(self, stage: UInt32, num_stages: UInt32) -> UInt32` Compute the staged column coordinate. **Args:** * ​stage ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current stage index. * ​num\_stages ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Total number of stages. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): Column coordinate for the current stage. ### `smem_coord_m` `smem_coord_m(self) -> UInt32` Compute shared memory M coordinate for TMA store. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): M coordinate in shared memory tile.
--- ## StMatrixWriter
`@register_passable(trivial)` `struct StMatrixWriter[c_type: DType, c_smem_layout: Layout, stageN: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]` Write register fragments to shared memory using st.matrix. Handles the complex swizzling and addressing required for efficient shared memory writes from WGMMA accumulator fragments. Template Parameters: c\_type: Output data type. c\_smem\_layout: Shared memory tile layout. stageN: Stage width in elements. c\_swizzle: TMA swizzle mode. transpose\_c: Whether output is transposed. ## Fields * ​swizzle (`Swizzle`): * ​lane\_id (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Config` `comptime Config = StMatrixConfig[c_type, stageN, c_swizzle, transpose_c]` ### `shape0` `comptime shape0 = c_smem_layout.shape[1].value() if (not transpose_c._mlir_value) else c_smem_layout.shape[0].value()` ### `stride0` `comptime stride0 = c_smem_layout.stride[0].value()` ### `stride1` `comptime stride1 = c_smem_layout.stride[1].value()` ### `stsmx_tile_offset` `comptime stsmx_tile_offset = (StMatrixWriter[c_type, c_smem_layout, stageN, c_swizzle, transpose_c].stride0 if transpose_c else StMatrixWriter[c_type, c_smem_layout, stageN, c_swizzle, transpose_c].stride1 * StMatrixWriter[c_type, c_smem_layout, stageN, c_swizzle, transpose_c].Config.stsmx_row_size)` ## Methods ### `__init__` `__init__(lane_id: UInt32) -> Self` Initialize the st.matrix writer. 
**Args:** * ​lane\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Lane ID within the warp. ### `compute_lane_offset` `compute_lane_offset(self) -> UInt32` Compute the base lane offset for st.matrix. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): Lane offset in shared memory. ### `write_fragment` `write_fragment[frag_size: Int](self, frag: SIMD[dtype, frag_size], dst: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_offset: UInt32 = 0)` Write a fragment to shared memory using st.matrix. **Args:** * ​frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Source fragment (typically from TMEM load). * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination shared memory tile. * ​warp\_offset ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Additional warp-based offset for transpose mode.
--- ## TMAStoreCoords
`@register_passable(trivial)` `struct TMAStoreCoords[BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, c_smem_shape0: Int, stage: Int]` Compute TMA store coordinates and warp election for SM100 epilogue. Encapsulates the complex coordinate computation logic for TMA stores, including cta\_group-specific branching and warp election. Template Parameters: BM: Block M dimension. BN: Block N dimension. MMA\_M: MMA M dimension. MMA\_N: MMA N dimension. stageN: Stage width in elements. cta\_group: Number of CTAs cooperating (1 or 2). c\_smem\_shape0: Shape\[0] of shared memory tile layout. stage: Current output stage index. ## Fields * ​coord\_m (`UInt`): * ​coord\_n (`UInt`): * ​elect\_one\_warp (`Bool`): * ​c\_smem\_coord\_m (`UInt`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CG1_TMA_BM` `comptime CG1_TMA_BM = c_smem_shape0` ### `CG2_TMA_BM` `comptime CG2_TMA_BM = c_smem_shape0 if (MMA_M == 256) else BM` ### `stage_n_offset` `comptime stage_n_offset = (stage * stageN)` ### `TMA_BM` `comptime TMA_BM = c_smem_shape0 if (MMA_M == 256) else BM if (cta_group == 2) else TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, c_smem_shape0, stage].CG1_TMA_BM` ## Methods ### `__init__` `__init__(c_coord: Tuple[UInt32, UInt32], warp_id: UInt32) -> Self` Compute all TMA store coordinates. **Args:** * ​c\_coord ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Output tile coordinates (m\_tile, n\_tile). 
* ​warp\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current warp ID.
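The flattened `TMA_BM` expression above loses the grouping of the original conditionals; from the `CG1_TMA_BM` and `CG2_TMA_BM` definitions, the intended selection appears to be `CG2_TMA_BM` when `cta_group == 2` and `CG1_TMA_BM` otherwise. A hedged Python sketch of that reading (this grouping is an inference, not stated in the source):

```python
def tma_bm(BM: int, MMA_M: int, cta_group: int, c_smem_shape0: int) -> int:
    """Assumed reading of TMA_BM: pick the CG2 variant only when two CTAs cooperate."""
    cg1_tma_bm = c_smem_shape0                          # CG1_TMA_BM
    cg2_tma_bm = c_smem_shape0 if MMA_M == 256 else BM  # CG2_TMA_BM
    return cg2_tma_bm if cta_group == 2 else cg1_tma_bm

print(tma_bm(BM=128, MMA_M=256, cta_group=2, c_smem_shape0=64))  # → 64
print(tma_bm(BM=128, MMA_M=128, cta_group=2, c_smem_shape0=64))  # → 128
print(tma_bm(BM=128, MMA_M=128, cta_group=1, c_smem_shape0=64))  # → 64
```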
--- ## TMAStoreExecutor
`@register_passable(trivial)` `struct TMAStoreExecutor[c_type: DType, c_smem_layout: Layout, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, stage_contiguous_size: Int, cta_group: Int, c_swizzle: TensorMapSwizzle, transpose_c: Bool, is_lower_frag_required: Bool]` Execute TMA store from shared memory to global memory with proper tiling. Encapsulates all the complex SMEM tiling/reshaping logic for TMA stores. Handles 3 distinct paths based on transpose\_c, cta\_group, and MMA\_M: 1. transpose\_c + cta\_group==2 + MMA\_M==128: Split reshape 2. transpose\_c + other: Loop over swizzle-width tiles 3. non-transpose: Simple tile selection Template Parameters: c\_type: Output data type. c\_smem\_layout: Shared memory layout for C tile. BM: Block M dimension. BN: Block N dimension. MMA\_M: MMA M dimension. MMA\_N: MMA N dimension. stageN: Stage width in elements. stage\_contiguous\_size: Contiguous size in SMEM layout. cta\_group: Number of CTAs cooperating (1 or 2). c\_swizzle: TensorMap swizzle mode. transpose\_c: Whether output is transposed. is\_lower\_frag\_required: Whether lower fragment is used. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `c_smem_shape0` `comptime c_smem_shape0 = c_smem_layout.shape[0].value()` ### `CG1_TMA_BM` `comptime CG1_TMA_BM = TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required].c_smem_shape0` ### `CG2_TMA_BM` `comptime CG2_TMA_BM = c_smem_layout.shape[0].value() if (MMA_M == 256) else BM` ### `num_c_smem_tiles` `comptime num_c_smem_tiles = ((128 // TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required].swizzle_width) // 1 if is_lower_frag_required else 2)` ### `swizzle_width` `comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())` ### `TMA_BM` `comptime TMA_BM = c_smem_layout.shape[0].value() if (MMA_M == 256) else BM if (cta_group == 2) else TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required].CG1_TMA_BM` ## Methods ### `execute` `static execute[c_layout: Layout, c_desc_layout: Layout](c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], store_coords: TMAStoreCoords[BM, BN, MMA_M, MMA_N, stageN, cta_group, TMAStoreExecutor[c_type, c_smem_layout, BM, BN, MMA_M, MMA_N, stageN, stage_contiguous_size, cta_group, c_swizzle, transpose_c, is_lower_frag_required].c_smem_shape0, 
stage], c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], warp_id: UInt32, lane: UInt32)` Execute TMA store with appropriate tiling for the configuration. **Args:** * ​c\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source shared memory tile. * ​store\_coords ([`TMAStoreCoords`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/tile_writer/TMAStoreCoords)): Precomputed TMA store coordinates. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA tensor tile for async store operations. * ​warp\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current warp ID. * ​lane ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Current lane ID within warp.
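The flattened `num_c_smem_tiles` expression also reads ambiguously once its parentheses are lost. One plausible reading, offered here only as an assumption, divides the 128-element span by the swizzle width and halves the tile count when the lower fragment is unused:

```python
def num_c_smem_tiles(swizzle_bytes: int, elem_bytes: int, is_lower_frag_required: bool) -> int:
    """Assumed grouping: (128 // swizzle_width) // (1 if lower frag else 2)."""
    swizzle_width = swizzle_bytes // elem_bytes
    return (128 // swizzle_width) // (1 if is_lower_frag_required else 2)

# bfloat16 (2 bytes) with SWIZZLE_128B: swizzle_width = 64 -- illustrative values
print(num_c_smem_tiles(128, 2, True))   # → 2
print(num_c_smem_tiles(128, 2, False))  # → 1
```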
--- ## TMEMFragment
`@register_passable(trivial)` `struct TMEMFragment[accum_type: DType, epilogue_type: DType, frag_size: Int]` Accumulator fragment pair from tensor memory. SM100 TMEM stores data in upper/lower fragment pairs due to the physical layout of tensor memory datapaths. Template Parameters: accum\_type: Accumulator data type (e.g., float32). epilogue\_type: Epilogue data type after casting (e.g., bfloat16). frag\_size: Number of elements per fragment. ## Fields * ​upper (`SIMD[accum_type, frag_size]`): * ​lower (`SIMD[accum_type, frag_size]`): * ​has\_lower (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(has_lower: Bool = True) -> Self` Initialize empty fragments. **Args:** * ​has\_lower ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether lower fragment is needed (based on MMA config). ### `cast_upper` `cast_upper(self) -> SIMD[epilogue_type, frag_size]` Cast upper fragment to epilogue type. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Upper fragment cast to epilogue\_type. ### `cast_lower` `cast_lower(self) -> SIMD[epilogue_type, frag_size]` Cast lower fragment to epilogue type. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Lower fragment cast to epilogue\_type.
--- ## TMEMReader
`@register_passable(trivial)` `struct TMEMReader[accum_type: DType, data_paths: Int = 16, bits: Int = 256, repeat: Int = 4]` Load accumulator fragments from tensor memory (TMEM). SM100 Blackwell GPUs have dedicated tensor memory for MMA accumulators. This struct encapsulates the tcgen05\_ld operations. Template Parameters: accum\_type: Accumulator data type. data\_paths: Number of datapaths (always 16 for SM100). bits: Bits per load (always 256 for SM100). repeat: Number of repetitions for wider loads. ## Fields * ​base\_addr (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `frag_size` `comptime frag_size = (((data_paths * (bits // 32)) // 32) * repeat)` ### `lower_offset` `comptime lower_offset = 1048576` ## Methods ### `__init__` `__init__(base_addr: UInt32) -> Self` Initialize TMEM reader. **Args:** * ​base\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base tensor memory address for the accumulator. ### `stage_addr` `stage_addr(self, stage: Int, stageN: Int) -> UInt32` Compute TMEM address for a given stage. **Args:** * ​stage ([`Int`](/mojo/stdlib/builtin/int/Int)): Stage index. * ​stageN ([`Int`](/mojo/stdlib/builtin/int/Int)): Stage width in elements. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): TMEM address for the stage.
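With the SM100 defaults (`data_paths=16`, `bits=256`, `repeat=4`), the documented `frag_size` formula yields 16 elements per thread. A small Python check of that arithmetic (the `16 << 16` identity for `lower_offset` is an observation about the constant's value, not a claim about the TMEM address encoding):

```python
def tmem_frag_size(data_paths: int = 16, bits: int = 256, repeat: int = 4) -> int:
    """frag_size = ((data_paths * (bits // 32)) // 32) * repeat, as documented above."""
    words_per_load = data_paths * (bits // 32)  # 32-bit words moved per tcgen05_ld
    return (words_per_load // 32) * repeat      # spread across the 32 lanes of a warp

print(tmem_frag_size())          # → 16
print(tmem_frag_size(repeat=1))  # → 4
print(1048576 == 16 << 16)       # → True: lower_offset equals 16 shifted into the upper half-word
```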
--- ## TMEMToSMemWriter
`@register_passable(trivial)` `struct TMEMToSMemWriter[c_type: DType, accum_type: DType, c_smem_layout: Layout, BM: Int, BN: Int, MMA_M: Int, MMA_N: Int, stageN: Int, cta_group: Int, num_output_warps: Int, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_c: Bool = False]` Write TMEM accumulator fragments to shared memory for SM100. This is the SM100-specific equivalent of SM90's FragmentToSMemWriter. Key difference: SM100 accumulators live in Tensor Memory (TMEM), not registers, so we need tcgen05\_ld to load them first. Handles three tile reshaping cases: 1. transpose\_c + is\_lower\_frag\_required: 2 warps share swizzle blocks 2. transpose\_c + !is\_lower\_frag\_required: 4 warps, upper only 3. !transpose\_c: Simple row-major tiling Template Parameters: c\_type: Output data type (e.g., bfloat16). accum\_type: Accumulator data type (e.g., float32). c\_smem\_layout: Shared memory tile layout. BM: Block M dimension. BN: Block N dimension. MMA\_M: MMA M dimension. MMA\_N: MMA N dimension. stageN: Stage N dimension. cta\_group: Number of CTAs cooperating (1 or 2). num\_output\_warps: Number of warps participating in output. c\_swizzle: TMA swizzle mode. transpose\_c: Whether output is transposed. 
## Fields * ​warp\_id (`UInt32`): * ​lane\_id (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Config` `comptime Config = EpilogueConfig[MMA_M, MMA_N, stageN, cta_group, transpose_c]` ### `data_paths` `comptime data_paths = 16` ### `stage_contiguous_size` `comptime stage_contiguous_size = c_smem_layout.shape[1].value()` ### `swizzle` `comptime swizzle = make_swizzle[c_type, c_swizzle]()` ### `swizzle_width` `comptime swizzle_width = (c_swizzle.bytes() // size_of[c_type]())` ## Methods ### `__init__` `__init__(warp_id: UInt32, lane_id: UInt32) -> Self` Initialize the TMEM to SMEM writer. **Args:** * ​warp\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Warp ID within the CTA. * ​lane\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Lane ID within the warp. ### `write_stage` `write_stage[repeat: Int, bits: Int = 256](self, tmem_addr: UInt32, stage: Int, c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128])` Write a single stage from TMEM to shared memory with tile reshaping. Automatically handles the correct tile reshaping based on transpose\_c and is\_lower\_frag\_required configuration. Template Parameters: repeat: Repeat factor for fragment loading. bits: TMEM bits width (default 256). **Args:** * ​tmem\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Base tensor memory address. * ​stage ([`Int`](/mojo/stdlib/builtin/int/Int)): Current stage index. 
* ​c\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Base shared memory tile (will be reshaped internally). ### `write_fragments` `write_fragments[repeat: Int](self, upper_frag: SIMD[c_type, (4 * repeat)], lower_frag: SIMD[c_type, (4 * repeat)], c_smem_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128])` Write pre-loaded fragments to shared memory with tile reshaping. Use this when fragments are loaded separately (e.g., with load\_tmem\_fragments) and need to be written after applying register-based epilogue. Template Parameters: repeat: Repeat factor matching the fragment size. **Args:** * ​upper\_frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Upper fragment (already casted to c\_type). * ​lower\_frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Lower fragment (already casted to c\_type, ignored if not needed). * ​c\_smem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Base shared memory tile (will be reshaped internally).
--- ## tile_writer
TileWriter components for SM100 matrix multiplication epilogue. This module provides modular components for the output pipeline: 1. **TMAStoreWriter**: TMA async store from shared memory to global memory 2. **StMatrixWriter**: Register to shared memory via st.matrix instructions 3. **TMEMReader**: Load accumulator data from tensor memory to registers 4. **EpilogueApplier**: Apply element-wise operations on fragments The SM100 epilogue pipeline flows as: TMEM (accumulators) → Registers → SMEM → GMEM (via TMA) Usage (TMA store from shared memory to global memory): `var tma_writer = TMAStoreWriter[...](c_tma_op)` followed by `tma_writer.store_tile(c_smem_tile, (n_coord, m_coord))` ## `comptime` values ### `RLayout32Bits` `comptime RLayout32Bits[layout: Layout] = RuntimeLayout[layout, element_type=DType.uint32, linear_idx_type=DType.uint32]` #### Parameters * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ### `ThreadwiseStoreWriter` `comptime ThreadwiseStoreWriter = TileWriterThreadwise[?, ?, ?]` ### `TMAStoreWriter` `comptime TMAStoreWriter = TileWriterTMA` ## Structs * [​`AccumBarrier`](./AccumBarrier): Helper for accumulator pipeline barrier operations. * [​`AccumTile`](./AccumTile): Accumulator tile holding upper and lower fragment data. * [​`EpilogueApplier`](./EpilogueApplier): Apply element-wise epilogue operations on register fragments. * [​`EpilogueConfig`](./EpilogueConfig): Configuration for epilogue stage computations. * [​`FragmentCoords`](./FragmentCoords): Compute coordinates for fragment elements in tensor memory layout. * [​`OutputStageWriter`](./OutputStageWriter): Orchestrate writing a single output stage. * [​`SMemEpilogueWriter`](./SMemEpilogueWriter): Write accumulator tile to SMEM and apply element-wise epilogue lambda. * [​`StMatrixConfig`](./StMatrixConfig): Configuration for st.matrix store operations. * [​`StMatrixCoords`](./StMatrixCoords): Compute coordinates for st.matrix operations. 

* [​`StMatrixWriter`](./StMatrixWriter): Write register fragments to shared memory using st.matrix. * [​`TMAStoreCoords`](./TMAStoreCoords): Compute TMA store coordinates and warp election for SM100 epilogue. * [​`TMAStoreExecutor`](./TMAStoreExecutor): Execute TMA store from shared memory to global memory with proper tiling. * [​`TMEMFragment`](./TMEMFragment): Accumulator fragment pair from tensor memory. * [​`TMEMReader`](./TMEMReader): Load accumulator fragments from tensor memory (TMEM). * [​`TMEMToSMemWriter`](./TMEMToSMemWriter): Write TMEM accumulator fragments to shared memory for SM100. ## Functions * [​`load_tmem_fragments`](./load_tmem_fragments): Load upper and lower fragments from TMEM and cast to epilogue type. * [​`shared_memory_epilogue`](./shared_memory_epilogue): Apply element-wise epilogue to non-transposed shared memory tile. * [​`shared_memory_epilogue_transpose`](./shared_memory_epilogue_transpose): Apply element-wise epilogue to transposed shared memory tile. * [​`store_fragment_to_smem`](./store_fragment_to_smem): Store a fragment to shared memory using st.matrix. * [​`tma_store_with_pipeline`](./tma_store_with_pipeline): Perform TMA store with pipelined commit and wait. * [​`tma_wait_pipelined`](./tma_wait_pipelined): Wait for TMA stores with pipelining.
--- ## load_tmem_fragments
`load_tmem_fragments[accum_type: DType, epilogue_type: DType, frag_size: Int, is_lower_required: Bool, data_paths: Int = 16, bits: Int = 256, repeat: Int = 1](tmem_addr: UInt32) -> Tuple[SIMD[epilogue_type, (frag_size * repeat)], SIMD[epilogue_type, (frag_size * repeat)]]` Load upper and lower fragments from TMEM and cast to epilogue type. This encapsulates the common pattern of loading accumulator data from tensor memory, waiting for completion, and casting to output type. Template Parameters: accum\_type: Accumulator data type (e.g., float32). epilogue\_type: Output data type after casting (e.g., bfloat16). frag\_size: Base fragment size per warp. is\_lower\_required: Whether lower fragment is needed. data\_paths: TMEM data paths (default 16). bits: TMEM bits width (default 256). repeat: Repeat factor for larger fragments. **Args:** * ​tmem\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Tensor memory address for this stage. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (upper\_casted, lower\_casted) SIMD fragments.
--- ## shared_memory_epilogue (tile_writer)
`shared_memory_epilogue[MMA_M: UInt, data_paths: UInt, num_stages: UInt, stage: UInt, stageN: UInt, c_type: DType, shared_n: UInt, simd_size: UInt, c_smem_upper_layout: Layout, c_smem_lower_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: UInt](M: UInt32, N: UInt32, c_col: UInt, c_row: UInt, c_smem_warp_tile_upper: LayoutTensor[c_type, c_smem_upper_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_smem_warp_tile_lower: LayoutTensor[c_type, c_smem_lower_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Apply element-wise epilogue to non-transposed shared memory tile. Handles the non-transpose case for SMEM-based epilogue. Processes upper and lower fragments separately with proper coordinate mapping. Template Parameters: MMA\_M: MMA M dimension. data\_paths: Number of data paths (typically 16). num\_stages: Total number of output stages. stage: Current output stage index. stageN: Stage width in elements. c\_type: Output data type. shared\_n: Shared memory N dimension. simd\_size: SIMD width for vectorized access. c\_smem\_upper\_layout: Layout for upper fragment tile. c\_smem\_lower\_layout: Layout for lower fragment tile. swizzle: Swizzle pattern for SMEM access. compute\_lambda\_fn: Element-wise compute function. num\_output\_warps: Number of warps participating. **Args:** * ​M ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Output M dimension. * ​N ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Output N dimension. * ​c\_col ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Base column coordinate. * ​c\_row ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Base row coordinate. 
* ​c\_smem\_warp\_tile\_upper ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Upper fragment shared memory tile. * ​c\_smem\_warp\_tile\_lower ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Lower fragment shared memory tile.
--- ## shared_memory_epilogue_transpose (tile_writer)
`shared_memory_epilogue_transpose[stage: UInt, stageN: UInt, c_type: DType, c_smem_layout: Layout, swizzle: Swizzle, compute_lambda_fn: elementwise_compute_lambda_type, num_output_warps: UInt, warp_dim: UInt, MMA_M: Int, BN: Int, cta_group: Int](M: UInt32, N: UInt32, c_col: UInt, c_row: UInt, c_smem: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_i: UInt, warp_j: UInt)` Apply element-wise epilogue to transposed shared memory tile. Handles the transpose\_c case for SMEM-based epilogue. Supports two warp configurations based on warp\_dim parameter. Template Parameters: stage: Current output stage index. stageN: Stage width in elements. c\_type: Output data type. c\_smem\_layout: Shared memory tile layout. swizzle: Swizzle pattern for SMEM access. compute\_lambda\_fn: Element-wise compute function. num\_output\_warps: Number of warps participating. warp\_dim: Warp dimension configuration (1 or 2). MMA\_M: MMA M dimension. BN: Block N dimension. cta\_group: Number of CTAs cooperating. **Args:** * ​M ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Output M dimension. * ​N ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Output N dimension. * ​c\_col ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Base column coordinate. * ​c\_row ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Base row coordinate. * ​c\_smem ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Shared memory tile. * ​warp\_i ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Warp index i. * ​warp\_j ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Warp index j.
--- ## store_fragment_to_smem
`store_fragment_to_smem[swizzle: Swizzle, stageN: Int, transpose_c: Bool = False, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B](vec: SIMD[dtype, size], dst: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], warp_offset: UInt32 = 0)` Store a fragment to shared memory using st.matrix. This function provides a static interface compatible with stsm\_helper, delegating to the underlying st.matrix operations. Template Parameters: swizzle: Pre-computed swizzle pattern. stageN: Stage width in elements. transpose\_c: Whether output is transposed. c\_swizzle: TMA swizzle mode (for configuration). **Args:** * ​vec ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Source SIMD fragment. * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination shared memory tile. * ​warp\_offset ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Additional warp-based offset for transpose mode.
--- ## tma_store_with_pipeline
`tma_store_with_pipeline[c_type: DType, c_layout: Layout, c_desc_layout: Layout, is_last_stage: Bool](c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout], src: LayoutTensor[c_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], coords: Tuple[UInt, UInt])` Perform TMA store with pipelined commit and wait. Encapsulates the common SM100 output pattern: 1. fence\_async\_view\_proxy() 2. async\_store() 3. commit\_group() 4. wait\_group() with pipelining Template Parameters: c\_type: Output data type. c\_layout: Global memory layout for C. c\_desc\_layout: TMA descriptor layout for C. is\_last\_stage: If True, wait for all; else keep 1 in flight. **Args:** * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA tensor tile descriptor. * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source shared memory tile. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Destination coordinates in global memory.
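The commit/wait pattern in steps 3 and 4 can be modeled as a small FIFO of in-flight store groups, where `wait_group(n)` retires the oldest groups until at most `n` remain. This toy Python model (the names and queue are illustrative, not the TMA hardware API) shows why non-last stages pass 1 and the last stage passes 0:

```python
# Toy model of the commit_group / wait_group pipelining used by
# tma_store_with_pipeline: non-last stages keep one store in flight so it
# overlaps with the next stage's work; the last stage drains everything.
in_flight: list = []

def commit_group(tag: str) -> None:
    """Enqueue a committed store group (models commit_group)."""
    in_flight.append(tag)

def wait_group(max_pending: int) -> list:
    """Retire oldest groups until at most max_pending remain (models wait_group)."""
    retired = []
    while len(in_flight) > max_pending:
        retired.append(in_flight.pop(0))  # oldest group completes first
    return retired

for stage in range(3):
    commit_group(f"stage{stage}")
    is_last = stage == 2
    wait_group(0 if is_last else 1)  # keep 1 in flight, except on the last stage

print(in_flight)  # → [] : the last stage waited for all outstanding stores
```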
--- ## tma_wait_pipelined
`tma_wait_pipelined[c_type: DType, c_layout: Layout, c_desc_layout: Layout, is_last_stage: Bool](c_tma_op: TMATensorTile[c_type, c_layout, c_desc_layout])` Wait for TMA stores with pipelining. For SM100 output pipeline: * Non-last stages: Keep 1 store in flight for pipelining * Last stage: Wait for all stores to complete Template Parameters: c\_type: Output data type. c\_layout: Global memory layout for C. c\_desc\_layout: TMA descriptor layout for C. is\_last\_stage: If True, wait for all; else keep 1 in flight. **Args:** * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA tensor tile descriptor.
--- ## create_matmul_configs_ampere
`create_matmul_configs_ampere[key: String, a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool]() -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## get_dispatch_table
`get_dispatch_table[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool]() -> Dict[String, MatmulConfig[a_type, b_type, c_type, transpose_b], default_comp_time_hasher]` **Returns:** `Dict`
--- ## dispatch (Dispatch)
## Functions * [​`create_matmul_configs_ampere`](./create_matmul_configs_ampere): * [​`get_dispatch_table`](./get_dispatch_table):
--- ## sm80
Provides the NVIDIA Ampere (SM80) GPU backend implementations for matmuls. ## Modules * [​`dispatch`](./dispatch/):
--- ## dispatch (3)
## `comptime` values ### `DISPATCH_HIT` `comptime DISPATCH_HIT = 1` ### `DISPATCH_MISS` `comptime DISPATCH_MISS = 0` ### `llama_405b_fp8_list` `comptime llama_405b_fp8_list = List[TuningConfigSM90](TuningConfigSM90(64, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(128, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(128, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(256, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(512, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 16384, 2048, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(64, 2304, 16384, IndexList[3, DType.int64](64, 48, 32, Tuple[]()), Index(64, 48, 128), 8, Index(1, 1, 1), 1, False, 
OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(128, 2304, 16384, IndexList[3, DType.int64](64, 48, 32, Tuple[]()), Index(64, 48, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(256, 2304, 16384, IndexList[3, DType.int64](64, 96, 32, Tuple[]()), Index(64, 96, 128), 4, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(512, 2304, 16384, IndexList[3, DType.int64](64, 144, 32, Tuple[]()), Index(128, 144, 128), 4, Index(1, 1, 1), 2, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 2304, 16384, IndexList[3, DType.int64](64, 144, 32, Tuple[]()), Index(128, 144, 128), 4, Index(1, 1, 1), 2, False, OptionalReg[IndexList[2]](Index(H100.sm_count, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(2048, 2304, 16384, IndexList[3, DType.int64](64, 144, 32, Tuple[]()), Index(128, 144, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(16, 8)), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 2304, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(64, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(128, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), 
OptionalReg[RasterOrder](None)), TuningConfigSM90(128, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(256, 13312, 16384, IndexList[3, DType.int64](64, 208, 32, Tuple[]()), Index(128, 208, 128), 4, Index(1, 2, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(512, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 13312, 16384, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(64, 16384, 6656, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, False, OptionalReg[IndexList[2]](Index(128, 1)), MatmulSchedule.DS_SCHEDULER, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, 16384, 6656, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, 16384, 6656, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 4, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), 
MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), Tuple[]())` ### `llama_405b_fp8_table` `comptime llama_405b_fp8_table = Table[TuningConfigSM90](llama_405b_fp8_list, "llama_405b_fp8")` ### `llama_8b_fp8_list` `comptime llama_8b_fp8_list = List[TuningConfigSM90](TuningConfigSM90(128, -1, -1, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(64, 128, 128), 8, Index(1, 1, 1), 1, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(1024, -1, -1, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 6, Index(1, 1, 1), 2, True, OptionalReg[IndexList[2]](None), MatmulSchedule.NONE, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), TuningConfigSM90(MAX_M, -1, -1, IndexList[3, DType.int64](64, 128, 32, Tuple[]()), Index(128, 128, 128), 6, Index(2, 1, 1), 2, True, OptionalReg[IndexList[2]](Index(8, (H100 // 8))), MatmulSchedule.TILE2D, OptionalReg[Int](None), OptionalReg[RasterOrder](None)), Tuple[]())` ### `llama_8b_fp8_table` `comptime llama_8b_fp8_table = Table[TuningConfigSM90](llama_8b_fp8_list, "llama_8b_fp8")` ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `MAX_M` `comptime MAX_M = Int.MAX` ## Functions * [​`matmul_dispatch_sm90`](./matmul_dispatch_sm90): * [​`matmul_dispatch_sm90_bf16_fp32`](./matmul_dispatch_sm90_bf16_fp32): * [​`matmul_dispatch_sm90_fp8`](./matmul_dispatch_sm90_fp8):
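The `DISPATCH_HIT`/`DISPATCH_MISS` constants and the tuning lists above implement shape-based kernel selection: each `TuningConfigSM90` entry is keyed by an M threshold (with `MAX_M = Int.MAX` as the catch-all bucket) plus fixed N and K values (or -1 wildcards, as in the `llama_8b` list). A minimal Python sketch of this threshold-style lookup (the config names and the `dispatch` helper are illustrative, not the actual `Table` API):

```python
# Hypothetical sketch of threshold-based config dispatch, not the real Table lookup.
MAX_M = 2**63 - 1  # sentinel threshold: matches any M, like MAX_M = Int.MAX

# (m_threshold, n, k, config) -- entries sorted by ascending M threshold;
# -1 means "any value" for that dimension, as in the llama_8b_fp8 list.
configs = [
    (64,    16384, 2048, "cfg_small"),
    (1024,  16384, 2048, "cfg_medium"),
    (MAX_M, 16384, 2048, "cfg_large"),
]

def dispatch(m, n, k):
    """Return the first config whose M threshold covers m and whose N/K match."""
    for m_thresh, cfg_n, cfg_k, name in configs:
        if m <= m_thresh and cfg_n in (n, -1) and cfg_k in (k, -1):
            return name
    return None  # DISPATCH_MISS: fall back to the generic path

print(dispatch(32, 16384, 2048))    # lands in the smallest M bucket: cfg_small
print(dispatch(4096, 16384, 2048))  # caught by the MAX_M entry: cfg_large
```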
--- ## matmul_dispatch_sm90
`matmul_dispatch_sm90[c_type: DType, a_type: DType, b_type: DType, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## matmul_dispatch_sm90_bf16_fp32
`matmul_dispatch_sm90_bf16_fp32[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## matmul_dispatch_sm90_fp8
`matmul_dispatch_sm90_fp8[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = True, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, pdl_level: PDLLevel = PDLLevel()](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## default_config_sm90
`default_config_sm90[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool, wgmma_shape: IndexList[3]]() -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## grouped_matmul_sm90
`grouped_matmul_sm90[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, //, *, transpose_b: Bool = True, wgmma_shape: IndexList[3] = Index(64, 256, 16), config: MatmulConfig[a_type, b_type, c_type, transpose_b] = default_config_sm90[a_type, b_type, c_type, transpose_b, wgmma_shape](), elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: NDBuffer[c_type, 2, MutAnyOrigin, c_shape], a: NDBuffer[a_type, 2, MutAnyOrigin, a_shape], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], max_num_tokens_per_expert: Int, b: NDBuffer[b_type, 3, MutAnyOrigin, b_shape], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], num_active_experts: Int, ctx: DeviceContext)`
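The signature above follows the standard Mixture-of-Experts grouped-GEMM convention: `a_offsets` delimits each group's token rows in `a`, and `expert_ids` selects a weight slice from the 3-D `b`. A NumPy sketch of the reference semantics under that reading (an illustration of the data layout only, not the SM90 kernel):

```python
import numpy as np

def grouped_matmul_ref(a, a_offsets, b, expert_ids, num_active_experts):
    """Reference semantics: token rows [a_offsets[g], a_offsets[g+1]) of `a`
    are multiplied by expert weights b[expert_ids[g]]. Assumes transpose_b=True,
    so b is laid out [num_experts, N, K]."""
    n = b.shape[1]
    c = np.zeros((a.shape[0], n), dtype=a.dtype)
    for g in range(num_active_experts):
        start, end = a_offsets[g], a_offsets[g + 1]
        c[start:end] = a[start:end] @ b[expert_ids[g]].T  # [tokens, K] @ [K, N]
    return c

# Two active experts over 5 tokens: rows 0-2 routed to expert 1, rows 3-4 to expert 0.
rng = np.random.default_rng(0)
a = rng.standard_normal((5, 8)).astype(np.float32)     # [tokens, K]
b = rng.standard_normal((2, 4, 8)).astype(np.float32)  # [experts, N, K]
c = grouped_matmul_ref(a, np.array([0, 3, 5]), b, np.array([1, 0]), 2)
```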
--- ## grouped_matmul (3)
## Functions * [​`default_config_sm90`](./default_config_sm90): * [​`grouped_matmul_sm90`](./grouped_matmul_sm90):
--- ## sm90
Provides the NVIDIA Hopper backend implementations for matmuls. ## Modules * [​`dispatch`](./dispatch/): * [​`grouped_matmul`](./grouped_matmul/): * [​`matmul`](./matmul/): * [​`matmul_kernel_persistent`](./matmul_kernel_persistent/): * [​`matmul_kernels`](./matmul_kernels/): * [​`matmul_output`](./matmul_output/): * [​`ring_buffer`](./ring_buffer/): Ring buffer implementation for producer-consumer synchronization in GPU kernels. * [​`testbed`](./testbed/): * [​`tile_loader`](./tile_loader/): TileLoader module for efficient tile loading in GPU matrix multiplication. * [​`tile_writer`](./tile_writer/): TileWriter module for efficient tile writing in GPU matrix multiplication. * [​`tuning_configs`](./tuning_configs/):
--- ## matmul (5)
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Functions * [​`warp_specialize_gemm_with_multicasting`](./warp_specialize_gemm_with_multicasting): Unified dispatcher for all matmul kernel variants. * [​`warp_specialize_gemm_with_multicasting_splitk`](./warp_specialize_gemm_with_multicasting_splitk):
--- ## warp_specialize_gemm_with_multicasting
`warp_specialize_gemm_with_multicasting[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], grid_shape: OptionalReg[IndexList[2]] = None, use_tma_store: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, schedule: MatmulSchedule = MatmulSchedule.NONE, hilbert_swizzle: Bool = False, splits: Int = 0, raster_order: RasterOrder = RasterOrder.AlongM](c_device: NDBuffer[c_type, 2, origin, c_shape], a_device: NDBuffer[a_type, 2, origin, a_shape], b_device: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)` Unified dispatcher for all matmul kernel variants.
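Among the dispatcher's parameters, `hilbert_swizzle` schedules thread blocks along a Hilbert curve so that consecutively launched blocks touch nearby (M, N) output tiles and reuse each other's L2 cache lines. A Python sketch of the classic distance-to-coordinate Hilbert mapping from which such a lookup table could be built (illustrative; the kernel itself consumes a precomputed LUT):

```python
def hilbert_d2xy(order, d):
    """Map distance d along a Hilbert curve to (x, y) on a 2**order x 2**order grid."""
    x = y = 0
    t = d
    s = 1
    while s < (1 << order):
        rx = 1 & (t // 2)
        ry = 1 & (t ^ rx)
        if ry == 0:  # rotate/reflect the quadrant
            if rx == 1:
                x, y = s - 1 - x, s - 1 - y
            x, y = y, x
        x += s * rx
        y += s * ry
        t //= 4
        s *= 2
    return x, y

# Consecutive d values land on adjacent tiles -- the locality the swizzle exploits.
lut = [hilbert_d2xy(2, d) for d in range(16)]  # 4x4 grid of thread blocks
```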
--- ## warp_specialize_gemm_with_multicasting_splitk
`warp_specialize_gemm_with_multicasting_splitk[c_type: DType, c_shape: DimList, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, *, transpose_b: Bool, config: MatmulConfig[a_type, b_type, c_type, transpose_b], splits: Int, raster_order: RasterOrder, use_tma_store: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None](c_device: NDBuffer[c_type, 2, origin, c_shape], a_device: NDBuffer[a_type, 2, origin, a_shape], b_device: NDBuffer[b_type, 2, origin, b_shape], ctx: DeviceContext)`
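Split-K, as selected by the `splits` parameter above, partitions the reduction dimension across cooperating groups: each computes a partial product over its K slice, and the partials are summed at the end. A NumPy sketch of the arithmetic (the reduction strategy only, not the GPU implementation):

```python
import numpy as np

def splitk_matmul(a, b, splits):
    """Compute a @ b by splitting the K dimension into `splits` partial
    products and reducing them, as a split-K kernel does across CTAs."""
    k = a.shape[1]
    bounds = np.linspace(0, k, splits + 1, dtype=int)  # K-slice boundaries
    partials = [a[:, s:e] @ b[s:e, :] for s, e in zip(bounds[:-1], bounds[1:])]
    return sum(partials)  # final reduction over the partial accumulators

a = np.random.default_rng(1).standard_normal((4, 12))
b = np.random.default_rng(2).standard_normal((12, 3))
c_split = splitk_matmul(a, b, splits=3)  # matches a @ b up to float rounding
```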
--- ## matmul_kernel_persistent
--- ## HopperMatmulSM90Kernel
`struct HopperMatmulSM90Kernel[a_type: DType, b_type: DType, c_type: DType, a_layout: Layout, b_layout: Layout, c_layout: Layout, c_smem_layout: Layout, block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], cluster_shape: StaticTuple[Int32, 3], num_pipeline_stages: Int, num_threads: Int = 128, transpose_b: Bool = True, a_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, b_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, c_swizzle: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE, partitioned_multicast: Bool = False, use_tma_store: Bool = False, promotion_frequency: Int = 1, pdl_level: PDLLevel = PDLLevel(), elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, hilbert_swizzle: Bool = False]` Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs. 
This kernel implements a highly optimized matrix multiplication (GEMM) using: * Tensor Memory Accelerator (TMA) for efficient global-to-shared memory transfers * Warp Group Matrix Multiply Accumulate (WGMMA) instructions for tensor cores * Multi-stage software pipelining for overlapping compute and memory operations * Producer-consumer model with separate warp groups for loading and computing **Template Parameters:** * ​a\_type, b\_type, c\_type: Data types for input and output matrices. * ​a\_layout, b\_layout, c\_layout: Memory layouts for matrices. * ​c\_smem\_layout: Shared memory layout for the output tile. * ​block\_tile\_shape: Tile dimensions \[M, N, K] processed by each thread block. * ​wgmma\_shape: Dimensions \[M, N, K] for each WGMMA instruction. * ​cluster\_shape: Thread block cluster dimensions for distributed shared memory. * ​num\_pipeline\_stages: Number of stages in the software pipeline (typically 3-7). * ​num\_threads: Number of threads per block (must be a multiple of 128). * ​transpose\_b: Whether the B matrix is transposed (required to be True). * ​a\_swizzle, b\_swizzle: Memory swizzling for bank-conflict-free access. * ​c\_swizzle: Swizzling for output writes. * ​partitioned\_multicast: Enable partitioned multicast for large tiles. * ​use\_tma\_store: Use TMA for storing output (vs. regular stores). * ​promotion\_frequency: How often to promote FP8 accumulation to higher precision. * ​pdl\_level: Programmatic Dependency Launch (PDL) level. * ​elementwise\_lambda\_fn: Optional epilogue function. * ​elementwise\_compute\_lambda\_fn: Optional compute function. * ​hilbert\_swizzle: Use a Hilbert curve for thread block scheduling. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `a_smem_layout` `comptime a_smem_layout = tile_layout_k_major[a_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, 
block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BM, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BK, a_swizzle]()` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `AccumRegTileType` `comptime AccumRegTileType = LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, 
c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `b_smem_layout` `comptime b_smem_layout = tile_layout_k_major[b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BN, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BK, b_swizzle]()` ### `BK` `comptime BK = block_tile_shape.__getitem__[3, DType.int64, Int](2)` ### `BM` `comptime BM = block_tile_shape.__getitem__[3, DType.int64, Int](0)` ### `BN` `comptime BN = block_tile_shape.__getitem__[3, DType.int64, Int](1)` ### `c_frag_size` `comptime c_frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // 128)` ### `cluster_size` `comptime cluster_size = Int.__init__[Int32](((cluster_shape.__getitem__[3, Int](0) * cluster_shape.__getitem__[3, Int](1)) * cluster_shape.__getitem__[3, Int](2)))` ### `num_consumer` `comptime num_consumer = ((num_threads // 128) - 1)` ### `num_consumer_threads` `comptime num_consumer_threads = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, 
cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_consumer * 128)` ### `num_m_mmas` `comptime num_m_mmas = ((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BM // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_consumer)` ### `num_n_mmas` `comptime num_n_mmas = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BN // wgmma_shape.__getitem__[3, DType.int64, Int](1))` ### `RingBuffer` `comptime RingBuffer[tma_transfer: Bool = True] = RingBuffer[a_type, b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, 
hilbert_swizzle].a_smem_layout, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].b_smem_layout, num_pipeline_stages, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_consumer, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].cluster_size, tma_transfer]` #### Parameters * ​tma\_transfer ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): ### `RingBufferConsumer` `comptime RingBufferConsumer[origin: MutOrigin, tma_transfer: Bool] = RingBufferConsumer[origin, RingBuffer[a_type, b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].a_smem_layout, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, 
c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].b_smem_layout, num_pipeline_stages, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_consumer, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].cluster_size, tma_transfer]]` #### Parameters * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): * ​tma\_transfer ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): ### `SMem` `comptime SMem = HopperMatmulSM90Kernel_SMem[a_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].a_smem_layout, b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].b_smem_layout, c_type, c_smem_layout, num_pipeline_stages]` ### `WgmmaOp` `comptime WgmmaOp = 
TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b]` ## Methods ### `validate_constraints` `static validate_constraints()` Validate common constraints for all kernel variants. ### `pipeline_init` `static pipeline_init()` Initialize pipeline synchronization barriers. This function ensures that all pipeline initialization (barriers, shared memory) is visible to all thread blocks in the cluster before proceeding. This is critical for correct producer-consumer synchronization. For multi-cluster configurations, uses fence and cluster sync. For single block, uses a simple barrier. ### `finalize_kernel` `static finalize_kernel()` Common finalization for all kernel variants. ### `multicast_mask` `static multicast_mask(rank_m: UInt, rank_n: UInt) -> Tuple[Int32, Int32]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `common_kernel_init` `static common_kernel_init() -> Tuple[UInt, UInt, UInt, UInt, UInt, Bool]` Common initialization for all kernel variants. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (warp\_group\_idx, warp\_group\_thread\_idx, rank\_m, rank\_n, warp\_id, lane\_predicate). 
### `build_ring_buffer` `static build_ring_buffer[tma_transfer: Bool = True](smem: HopperMatmulSM90Kernel_SMem[a_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].a_smem_layout, b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].b_smem_layout, c_type, c_smem_layout, num_pipeline_stages], warp_group_thread_idx: UInt) -> RingBuffer[a_type, b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].a_smem_layout, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].b_smem_layout, num_pipeline_stages, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, 
use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_consumer, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].cluster_size, tma_transfer]` Create ring buffer for producer-consumer synchronization. **Returns:** [`RingBuffer`](/mojo/kernels/linalg/matmul/gpu/amd/ring_buffer/RingBuffer) ### `setup_producer` `static setup_producer() -> Int` Setup producer warp group by deallocating registers. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Number of registers deallocated. ### `setup_consumer` `static setup_consumer(warp_group_idx: UInt) -> Tuple[UInt, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].AccumRegTileType, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].AccumRegTileType]` Setup consumer warp group. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (local\_warp\_group\_idx, c\_reg\_tile, final\_c\_reg\_tile). 
### `get_block_swizzle` `static get_block_swizzle(lut_ptr: LegacyUnsafePointer[UInt32] = LegacyUnsafePointer[UInt32, AddressSpace.GENERIC, True, MutAnyOrigin]()) -> IndexList[2, element_type=DType.uint32]` Calculate block swizzle for better L2 cache locality. **Args:** * ​lut\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Lookup table for Hilbert curve block scheduling (optional). **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): Swizzled block indices. ### `consumer_output` `static consumer_output[custom_elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = elementwise_lambda_fn](c_tma_op: TMATensorTile[c_type, layout, desc_layout], c: LayoutTensor[c_type, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_tile: LayoutTensor[c_type, c_smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=NVIDIASharedMemoryBasePtr.alignment], output_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, 
b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], warp_group_thread_idx: UInt, local_warp_group_idx: UInt, local_thread_idx: UInt, block_y: Int, block_x: Int)` Handle consumer output by writing GEMM results to global memory. ### `build_tma_loaders` `static build_tma_loaders[a_tile_layout: Layout, b_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, //](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], rank_m: UInt, rank_n: UInt) -> Tuple[TileLoaderTMA[a_tma_op, a_type, a_tile_layout, a_desc_layout, BK=UInt(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BK), cluster_size=cluster_shape.__getitem__[3, Int](0), use_partitioned_multicast=partitioned_multicast], TileLoaderTMA[b_tma_op, b_type, b_tile_layout, b_desc_layout, BK=UInt(HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, 
transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BK), cluster_size=cluster_shape.__getitem__[3, Int](1), use_partitioned_multicast=partitioned_multicast]]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `build_cpasync_loaders` `static build_cpasync_loaders[k_align: Int, vector_size: Int = (k_align // size_of[a_type]()), num_threads_per_row: Int = (HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].BK // vector_size), thread_layout: Layout = Layout.row_major((WARPGROUP_SIZE // num_threads_per_row), num_threads_per_row)](a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin]) -> Tuple[TileLoaderCPAsync[a_type, a_layout, thread_layout, a_swizzle, vector_size], TileLoaderCPAsync[b_type, b_layout, thread_layout, b_swizzle, vector_size]]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `producer_main_loop` `static producer_main_loop[a_loader_type: TileLoader, b_loader_type: TileLoader, //, num_k_iters: Int](m_coord: UInt, n_coord: UInt, k_coord: UInt, a_loader: a_loader_type, b_loader: b_loader_type, mut ring_buffer: RingBuffer[a_loader_type._dtype, b_loader_type._dtype, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer])` Polymorphic A and B Tile Loader, works with both TMA and CPAsync. 
### `run` `static run[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b: LayoutTensor[b_type, b_layout, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], lut_ptr: LegacyUnsafePointer[UInt32])` Main kernel entry point for matrix multiplication. This kernel implements a producer-consumer pattern where: * One warp group (producer) loads tiles from global memory using TMA * Multiple warp groups (consumers) perform matrix multiplication using tensor cores The kernel uses software pipelining to overlap memory transfers with computation, achieving high throughput on Hopper GPUs. **Args:** * ​a\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix A. * ​b\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix B. * ​c\_tma\_op ([`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)): TMA descriptor for matrix C. * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix A. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input matrix B. * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output matrix C. * ​lut\_ptr ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Lookup table for Hilbert curve block scheduling (optional). 
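The payoff of the software pipelining described above can be sketched with a toy single-threaded timing model (illustrative only, not the Mojo kernel): with `stages` buffer slots, tile loads overlap with tensor-core compute, so the steady-state cost per tile approaches `max(t_load, t_mma)` instead of `t_load + t_mma`.

```python
def serial_time(n_tiles, t_load, t_mma):
    # no overlap: every tile pays its load latency in full
    return n_tiles * (t_load + t_mma)

def pipelined_time(n_tiles, t_load, t_mma, stages):
    # event model: the producer may run `stages` tiles ahead of the
    # consumer; a slot is reloaded only after its previous tile is consumed
    load_done, comp_done = [], []
    for i in range(n_tiles):
        start = load_done[i - 1] if i > 0 else 0.0
        if i >= stages:
            start = max(start, comp_done[i - stages])
        load_done.append(start + t_load)
        prev = comp_done[i - 1] if i > 0 else 0.0
        comp_done.append(max(prev, load_done[i]) + t_mma)
    return comp_done[-1]

# with 4 stages, total time collapses from 8*(2+3)=40 to about 2 + 8*3 = 26
print(serial_time(8, 2, 3), pipelined_time(8, 2, 3, 4))
```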
### `run_splitk` `static run_splitk[a_tile_layout: Layout, b_tile_layout: Layout, c_tma_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout, splits: Int, raster_order: RasterOrder](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tma_layout, c_desc_layout], c: LayoutTensor[c_type, c_layout, MutAnyOrigin], workspace_buffer: NDBuffer[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, 3, MutAnyOrigin], locks_ptr: LegacyUnsafePointer[UInt8], problem_shape: IndexList[3])` Split-K variant of the kernel for better load balancing on small problems. ### `run_grouped` `static run_grouped[a_tile_layout: Layout, b_tile_layout: Layout, c_tile_layout: Layout, a_desc_layout: Layout, b_desc_layout: Layout, c_desc_layout: Layout](a_tma_op: TMATensorTile[a_type, a_tile_layout, a_desc_layout], b_tma_op: TMATensorTile[b_type, b_tile_layout, b_desc_layout], c_tma_op: TMATensorTile[c_type, c_tile_layout, c_desc_layout], a_offsets: NDBuffer[DType.uint32, 1, MutAnyOrigin], expert_ids: NDBuffer[DType.int32, 1, MutAnyOrigin], c: LayoutTensor[c_type, c_layout, MutAnyOrigin])` Grouped matmul variant for MoE (Mixture of Experts) models. This variant handles multiple experts where each expert processes a subset of tokens. The a\_offsets array indicates token boundaries for each expert. 
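The split-K idea behind `run_splitk` can be illustrated with a plain-Python sketch (the `workspace` list loosely mirrors the role of `workspace_buffer`; this is an illustrative model, not the kernel's actual partitioning or lock-guarded reduction scheme): the K dimension is divided across `splits` partial matmuls, each writing to its own workspace slice, followed by a reduction over the split dimension.

```python
def matmul_ref(A, B):
    M, K, N = len(A), len(B), len(B[0])
    return [[sum(A[i][k] * B[k][j] for k in range(K)) for j in range(N)]
            for i in range(M)]

def matmul_splitk(A, B, splits):
    M, K, N = len(A), len(B), len(B[0])
    # partition K into `splits` contiguous ranges
    bounds = [round(s * K / splits) for s in range(splits + 1)]
    # one partial-result buffer per split (loosely mirrors workspace_buffer)
    workspace = []
    for s in range(splits):
        partial = [[sum(A[i][k] * B[k][j]
                        for k in range(bounds[s], bounds[s + 1]))
                    for j in range(N)] for i in range(M)]
        workspace.append(partial)
    # final reduction over the split dimension (lock-guarded in the kernel)
    return [[sum(workspace[s][i][j] for s in range(splits)) for j in range(N)]
            for i in range(M)]

A = [[1, 2, 3, 4], [5, 6, 7, 8]]
B = [[1, 0], [0, 1], [2, 2], [1, 3]]
assert matmul_splitk(A, B, 2) == matmul_ref(A, B)
```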
### `consumer_main_loop` `static consumer_main_loop[ring_buffer_origin: MutOrigin, //, num_k_iters: Int](wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: UInt, final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, 
use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], mut ring_buffer: RingBufferConsumer[ring_buffer_origin, RingBuffer[a_type, b_type, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, 
num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].a_smem_layout, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].b_smem_layout, num_pipeline_stages, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_consumer, HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].cluster_size, tma_transfer]])` Main computation loop for consumer warp groups. This function implements the core matrix multiplication using tensor cores. It consumes tiles from the ring buffer and accumulates results using WGMMA (Warp Group Matrix Multiply Accumulate) instructions. For FP8 data types, it periodically promotes intermediate results to higher precision to maintain accuracy. **Args:** * ​wgmma\_op ([`TensorCoreAsync`](/mojo/kernels/layout/tensor_core_async/TensorCoreAsync)): Tensor core operator for matrix multiplication. 
* ​local\_warp\_group\_idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Index of this consumer warp group (0-based). * ​final\_c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Final accumulation register tile (for FP8 promotion). * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Working accumulation register tile. * ​ring\_buffer ([`RingBufferConsumer`](/mojo/kernels/linalg/matmul/gpu/sm90/matmul_kernels/#ringbufferconsumer)): Consumer handle for synchronized tile access. ### `promote_to_cuda_cores` `static promote_to_cuda_cores(c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, 
partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], final_c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])` Promote FP8 accumulation to higher precision using CUDA cores. When using FP8 data types, tensor cores accumulate in limited precision. 
To maintain accuracy over many accumulations, we periodically add the intermediate results to a higher-precision accumulator using CUDA cores. This technique is commonly used in production libraries like cuBLAS to achieve both high performance (from FP8 tensor cores) and good accuracy. **Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Current accumulation from tensor cores. * ​final\_c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Higher-precision accumulator (updated in place). ### `wgmma` `static wgmma(wgmma_op: TensorCoreAsync[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, a_type, b_type, wgmma_shape, a_swizzle, b_swizzle, transpose_b], local_warp_group_idx: UInt, a_tile: LayoutTensor[a_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_tile: LayoutTensor[b_type, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].accum_type, Layout.row_major((HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, 
c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_m_mmas * HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].num_n_mmas), HopperMatmulSM90Kernel[a_type, b_type, c_type, a_layout, b_layout, c_layout, c_smem_layout, block_tile_shape, wgmma_shape, cluster_shape, num_pipeline_stages, num_threads, transpose_b, a_swizzle, b_swizzle, c_swizzle, partitioned_multicast, use_tma_store, promotion_frequency, pdl_level, elementwise_lambda_fn, elementwise_compute_lambda_fn, hilbert_swizzle].c_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL])`
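The promotion pattern implemented by `promote_to_cuda_cores` can be modeled in a few lines of Python (illustrative only; this shows the control flow, not FP8 numerics): a working accumulator is folded into a higher-precision accumulator every `promotion_frequency` steps, then reset, so low-precision error never compounds over more than `promotion_frequency` accumulations.

```python
def accumulate_with_promotion(values, promotion_frequency):
    final_acc = 0.0  # higher-precision accumulator (final_c_reg_tile)
    work_acc = 0.0   # tensor-core working accumulator (c_reg_tile)
    for i, v in enumerate(values, start=1):
        work_acc += v                    # WGMMA accumulation step
        if i % promotion_frequency == 0:
            final_acc += work_acc        # periodic promotion on CUDA cores
            work_acc = 0.0               # working tile restarts from zero
    return final_acc + work_acc          # fold in any leftover partial sum

assert accumulate_with_promotion(range(1, 11), 4) == 55
```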
--- ## HopperMatmulSM90Kernel_SMem
`@register_passable(trivial)` `struct HopperMatmulSM90Kernel_SMem[a_type: DType, a_layout: Layout, b_type: DType, b_layout: Layout, c_type: DType, c_layout: Layout, num_pipeline_stages: Int]` Shared memory layout for Hopper SM90 matrix multiplication kernel. This struct manages the shared memory allocation for: * Input tiles (A and B matrices) with multi-stage pipelining * Output tile (C matrix) for accumulation * Synchronization barriers for producer-consumer coordination The memory is organized to support asynchronous loads and efficient bank-conflict-free access patterns for tensor core operations. ## Fields * ​a\_tiles (`HopperMatmulSM90Kernel_SMem[a_type, a_layout, b_type, b_layout, c_type, c_layout, num_pipeline_stages].ATileArray`): * ​b\_tiles (`HopperMatmulSM90Kernel_SMem[a_type, a_layout, b_type, b_layout, c_type, c_layout, num_pipeline_stages].BTileArray`): * ​c\_tile (`HopperMatmulSM90Kernel_SMem[a_type, a_layout, b_type, b_layout, c_type, c_layout, num_pipeline_stages].CTile`): * ​full\_mbar (`HopperMatmulSM90Kernel_SMem[a_type, a_layout, b_type, b_layout, c_type, c_layout, num_pipeline_stages].PipelineBarrier`): * ​empty\_mbar (`HopperMatmulSM90Kernel_SMem[a_type, a_layout, b_type, b_layout, c_type, c_layout, num_pipeline_stages].PipelineBarrier`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATileArray` `comptime ATileArray = SMemTileArrayType[a_type, a_layout, num_pipeline_stages, NVIDIASharedMemoryBasePtr.alignment]` ### `BTileArray` `comptime 
BTileArray = SMemTileArrayType[b_type, b_layout, num_pipeline_stages, NVIDIASharedMemoryBasePtr.alignment]` ### `CTile` `comptime CTile = LayoutTensor[c_type, c_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=NVIDIASharedMemoryBasePtr.alignment]` ### `PipelineBarrier` `comptime PipelineBarrier = SMemArrayType[SharedMemBarrier, num_pipeline_stages]` ### `SMM` `comptime SMM = SharedMemoryManager[NVIDIASharedMemoryBasePtr]` ## Methods ### `__init__` `__init__() -> Self` ### `pipeline_storage_size` `static pipeline_storage_size() -> Int` Calculate the memory size for all pipeline stages. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `output_storage_size` `static output_storage_size() -> Int` Calculate the memory size for output tile. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `storage_size` `static storage_size() -> Int` Calculate the total storage size. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## find_K_alignment_upto_16B
`find_K_alignment_upto_16B(row_bytes_arg: Int) -> Int` Find alignment among 1B, 2B, 4B, 8B, and 16B based on the row's size in bytes. This function determines the largest power-of-2 alignment (up to 16 bytes) that evenly divides the given row size. This is used to determine the optimal vector size for cp.async operations when the K dimension's alignment doesn't meet TMA requirements. **Args:** * ​row\_bytes\_arg ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of bytes in a row (K \* sizeof(element)). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Alignment in bytes (1, 2, 4, 8, or 16).
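For illustration, the documented behavior can be sketched in Python (a sketch of the stated contract; the Mojo implementation may differ):

```python
def find_k_alignment_upto_16b(row_bytes: int) -> int:
    # largest power-of-2 alignment, capped at 16 bytes, dividing the row size
    align = 16
    while align > 1 and row_bytes % align != 0:
        align //= 2
    return align

# e.g. a 24-byte row supports 8B-aligned vector loads, a 6-byte row only 2B
assert [find_k_alignment_upto_16b(n) for n in (64, 24, 4, 6, 7)] == [16, 8, 4, 2, 1]
```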
--- ## matmul_kernels (Matmul_kernels)
## Structs * [​`HopperMatmulSM90Kernel`](./HopperMatmulSM90Kernel): Hopper SM90 Matrix Multiplication kernel optimized for NVIDIA H100 GPUs. * [​`HopperMatmulSM90Kernel_SMem`](./HopperMatmulSM90Kernel_SMem): Shared memory layout for Hopper SM90 matrix multiplication kernel. ## Functions * [​`find_K_alignment_upto_16B`](./find_K_alignment_upto_16B): Find alignment among 1B, 2B, 4B, 8B, and 16B based on the row's size in bytes.
--- ## MatmulTileWriter
`@register_passable(trivial)` `struct MatmulTileWriter[dtype: DType, layout: Layout, address_space: AddressSpace, element_layout: Layout, layout_int_type: DType, linear_idx_type: DType, masked: Bool, alignment: Int, smem_tile_layout: Layout, //, *, BM: Int, BN: Int, swizzle: TensorMapSwizzle, wgmma_shape: IndexList[3], num_consumer: Int = 1, use_tma_store: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None]` ## Fields * ​tensor (`MatmulTileWriter[BM=BM, BN=BN, swizzle=swizzle, wgmma_shape=wgmma_shape, num_consumer=num_consumer, use_tma_store=use_tma_store, elementwise_lambda_fn=elementwise_lambda_fn, elementwise_compute_lambda_fn=elementwise_compute_lambda_fn].CTensorType`): * ​smem\_tile (`LayoutTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]`): * ​warp\_group\_thread\_idx (`UInt`): * ​local\_warp\_group\_idx (`UInt`): * ​local\_thread\_idx (`UInt`): * ​block\_y (`Int`): * ​block\_x (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CTensorType` `comptime CTensorType = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, 
masked=masked, alignment=alignment]` ### `frag_size` `comptime frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // WARPGROUP_SIZE)` ### `lambda_type` `comptime lambda_type = fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], mut SIMD[dtype, width]) capturing -> None` ### `N` `comptime N = layout.shape[1].value()` ### `num_consumer_threads` `comptime num_consumer_threads = (num_consumer * WARPGROUP_SIZE)` ### `num_m_mmas` `comptime num_m_mmas = ((BM // wgmma_shape.__getitem__[3, DType.int64, Int](0)) // num_consumer)` ### `num_n_mmas` `comptime num_n_mmas = (BN // wgmma_shape.__getitem__[3, DType.int64, Int](1))` ### `simd_size` `comptime simd_size = simd_width_of[dtype]()` ### `WG_BM` `comptime WG_BM = smem_tile_layout.shape[0].value()` ### `WG_BN` `comptime WG_BN = smem_tile_layout.shape[1].value()` ## Methods ### `__init__` `__init__(tensor: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], smem_tile: LayoutTensor[dtype, smem_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_group_thread_idx: UInt, local_warp_group_idx: UInt, local_thread_idx: UInt, block_y: Int, block_x: Int) -> Self` ### `write_tile` `write_tile[tma_layout: Layout, desc_layout: Layout, accum_type: DType, reg_tile_layout: Layout, //](self, tma_op: TMATensorTile[dtype, tma_layout, desc_layout], reg_tile: LayoutTensor[accum_type, reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL])` Write output from registers to global memory. Selects optimized st.matrix path for bf16 when constraints are met, otherwise uses general register-to-global path.
--- ## matmul_output (Matmul_output)
## Structs * [​`MatmulTileWriter`](./MatmulTileWriter):
--- ## ConsumerTiles (Ring_buffer)
`@register_passable(trivial)` `struct ConsumerTiles[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_consumers: Int, cluster_size: Int, tma_transfer: Bool, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer]]]` Context manager for consumer access to ring buffer tiles. This struct provides safe access to a single tile slot in the ring buffer for consumers to read. It tracks the read index and automatically releases the slot when exiting the context. ## Fields * ​ring\_buffer\_ptr (`ConsumerTiles[origin, ring_buffer_type].RingBufferPtrType`): * ​read\_idx (`UInt32`): * ​a\_tile (`ConsumerTiles[origin, ring_buffer_type].ATile`): * ​b\_tile (`ConsumerTiles[origin, ring_buffer_type].BTile`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = ring_buffer_type.ATile` ### `BTile` `comptime BTile = ring_buffer_type.BTile` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)`
--- ## ProducerTiles (Ring_buffer)
`@register_passable(trivial)` `struct ProducerTiles[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_consumers: Int, cluster_size: Int, tma_transfer: Bool, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer]]]` Context manager for producer access to ring buffer tiles. This struct provides safe access to a single tile slot in the ring buffer for the producer to fill. It automatically handles barrier synchronization when entering and exiting the context. ## Fields * ​ring\_buffer\_ptr (`ProducerTiles[origin, ring_buffer_type].RingBufferPtrType`): * ​barrier (`SMemBarrier`): * ​a\_tile (`ProducerTiles[origin, ring_buffer_type].ATile`): * ​b\_tile (`ProducerTiles[origin, ring_buffer_type].BTile`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = ring_buffer_type.ATile` ### `BTile` `comptime BTile = ring_buffer_type.BTile` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `__exit__` `__exit__(mut self)`
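The enter/exit protocol of `ProducerTiles`/`ConsumerTiles` can be mirrored with a toy Python context manager (all classes here are illustrative analogues; the real structs track slots with shared-memory barriers, not Python state): a filled slot is acquired on `__enter__` and released back to the producer on `__exit__`.

```python
class ToyRingBuffer:
    """Single-threaded stand-in for the ring buffer's slot bookkeeping."""
    def __init__(self, stages):
        self.slots = [None] * stages
        self.full = [False] * stages
        self.read = 0
        self.write = 0

    def put_tile(self, tile):            # producer: get_slot + enqueue_tile
        i = self.write % len(self.slots)
        assert not self.full[i], "slot not yet released by consumers"
        self.slots[i], self.full[i] = tile, True
        self.write += 1

    def get_tile(self):                  # consumer: wait for a full slot
        i = self.read % len(self.slots)
        assert self.full[i], "no tile ready"
        self.read += 1
        return i, self.slots[i]

    def release_slot(self, idx):         # consumer: mark the slot empty
        self.full[idx] = False

class ConsumerTilesCtx:
    """Toy analogue of ConsumerTiles: acquire on enter, release on exit."""
    def __init__(self, buf):
        self.buf = buf
    def __enter__(self):
        self.read_idx, self.tile = self.buf.get_tile()
        return self
    def __exit__(self, *exc):
        self.buf.release_slot(self.read_idx)

rb = ToyRingBuffer(stages=2)
rb.put_tile("t0")
rb.put_tile("t1")
with ConsumerTilesCtx(rb) as c:
    assert c.tile == "t0"
rb.put_tile("t2")   # slot 0 was released on exit, so the producer may reuse it
```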
--- ## RingBuffer (Ring_buffer)
`struct RingBuffer[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_consumers: Int, cluster_size: Int, tma_transfer: Bool = True]` Ring buffer for managing pipeline synchronization between producers and consumers. This struct encapsulates the synchronization logic for a multi-stage pipeline with one producer and multiple consumers, supporting both single-block and multi-cluster configurations. The ring buffer uses two sets of barriers: * full\_mbar: Signals when tiles are ready for consumption * empty\_mbar: Signals when slots are available for production **Parameters:** * ​a\_type: Data type for A matrix tiles. * ​b\_type: Data type for B matrix tiles. * ​a\_tile\_layout: Memory layout for A tiles. * ​b\_tile\_layout: Memory layout for B tiles. * ​num\_pipeline\_stages: Number of stages in the circular buffer. * ​num\_consumers: Number of consumer warp groups. * ​cluster\_size: Number of blocks in the cluster (1 for single-block). * ​tma\_transfer: Whether the RingBuffer is used for TMA transfers (default: True). ## Fields * ​full\_mbar (`RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].PipelineBarrier`): * ​empty\_mbar (`RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].PipelineBarrier`): * ​read\_state (`PipelineState[num_pipeline_stages]`): * ​write\_state (`PipelineState[num_pipeline_stages]`): * ​warp\_group\_thread\_idx (`UInt`): * ​a\_tiles (`RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].ATileArray`): * ​b\_tiles (`RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].BTileArray`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), 
[`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].ATileArray.Tile` ### `ATileArray` `comptime ATileArray = SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, NVIDIASharedMemoryBasePtr.alignment]` ### `BTile` `comptime BTile = RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].BTileArray.Tile` ### `BTileArray` `comptime BTileArray = SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, NVIDIASharedMemoryBasePtr.alignment]` ### `PipelineBarrier` `comptime PipelineBarrier = SMemArrayType[SharedMemBarrier, num_pipeline_stages]` ### `SMM` `comptime SMM = SharedMemoryManager[NVIDIASharedMemoryBasePtr]` ## Methods ### `__init__` `__init__(out self, full_mbar: SMemArrayType[SharedMemBarrier, num_pipeline_stages], empty_mbar: SMemArrayType[SharedMemBarrier, num_pipeline_stages], warp_group_thread_idx: UInt, a_tiles: SMemTileArrayType[a_type, a_tile_layout, num_pipeline_stages, NVIDIASharedMemoryBasePtr.alignment], b_tiles: SMemTileArrayType[b_type, b_tile_layout, num_pipeline_stages, NVIDIASharedMemoryBasePtr.alignment])` Initialize ring buffer with barrier pointers. **Args:** * ​full\_mbar (`SMemArrayType`): Barrier array signaling when tiles are ready. * ​empty\_mbar (`SMemArrayType`): Barrier array signaling when slots are empty. * ​warp\_group\_thread\_idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Thread index within the warp group. 
* ​a\_tiles ([`SMemTileArrayType`](/mojo/kernels/linalg/structuring/SMemTileArrayType)): Iterator over A matrix tile storage. * ​b\_tiles ([`SMemTileArrayType`](/mojo/kernels/linalg/structuring/SMemTileArrayType)): Iterator over B matrix tile storage. ### `__enter__` `__enter__(mut self) -> Self` Context manager entry. ### `get_expected_bytes` `static get_expected_bytes() -> Int` Calculate expected bytes per pipeline stage for TMA transfers. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `get_slot` `get_slot(mut self) -> UInt32` Producer waits for empty buffer slot and prepares for loading. This method blocks until the current write slot is empty (all consumers have finished with it), then prepares the barrier for new data. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): Index of the available slot in the ring buffer. ### `get_producer_tiles` `get_producer_tiles(mut self) -> Tuple[SMemBarrier, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].ATile, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].BTile]` Get the next available slot for the producer to fill. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (barrier, a\_tile, b\_tile) for the producer to use. ### `enqueue_tile` `enqueue_tile(mut self)` Producer signals that tile loading is complete. This handles the specific signaling pattern needed: * For cp.async: Signal async copy arrival and barrier arrival * For TMA: Barrier arrival is handled by hardware After signaling, advances to the next pipeline stage. ### `get_tile` `get_tile(mut self) -> UInt32` Consumer waits for full buffer slot. This method blocks until the producer has filled the current read slot. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): Index of the available tile to consume. 
### `get_consumer_tiles` `get_consumer_tiles(mut self) -> Tuple[UInt32, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].ATile, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer].BTile]` Consumer waits for full buffer slot and returns the tiles. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (read\_idx, a\_tile, b\_tile) for the consumer to process. ### `release_slot` `release_slot(mut self, read_idx: UInt32)` Consumer signals that buffer slot is empty. This allows the producer to reuse this slot in the ring buffer. Different arrival patterns are used for single-block vs multi-cluster. **Args:** * ​read\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Index of the slot to release. ### `consumer` `consumer(mut self) -> RingBufferConsumer[self, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer]]` Create a consumer view of this ring buffer. **Returns:** [`RingBufferConsumer`](/mojo/kernels/linalg/matmul/gpu/sm90/matmul_kernels/#ringbufferconsumer) ### `producer` `producer(mut self) -> RingBufferProducer[self, RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer]]` Create a producer view of this ring buffer. **Returns:** `RingBufferProducer` ### `arrive_empty_barriers` `arrive_empty_barriers(self)` Helper to arrive at empty barriers during consumer initialization. This is called when consumers enter their context to signal they are ready to consume tiles. It ensures all pipeline stages start with empty slots available for the producer.
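The methods above pair into a simple protocol: the producer alternates `get_producer_tiles` (or `get_slot`) with `enqueue_tile`, while each consumer alternates `get_consumer_tiles` (or `get_tile`) with `release_slot`. A minimal sketch of that low-level loop, assuming hypothetical `loader` and `mma` helpers and a `num_k_tiles` bound that are not part of this API:

```mojo
# Producer warp group: wait for empty slots and fill them.
for k in range(num_k_tiles):
    var barrier, a_tile, b_tile = ring_buffer.get_producer_tiles()
    loader.load_tile(a_tile, barrier, (block_m, k))  # hypothetical loader
    loader.load_tile(b_tile, barrier, (block_n, k))
    ring_buffer.enqueue_tile()  # signal "tile ready", advance write state

# Consumer warp groups: signal readiness once, then drain slots.
ring_buffer.arrive_empty_barriers()
for k in range(num_k_tiles):
    var read_idx, a_tile, b_tile = ring_buffer.get_consumer_tiles()
    mma(a_tile, b_tile)                 # hypothetical tensor-core step
    ring_buffer.release_slot(read_idx)  # slot may now be refilled
```

The `producer()`/`consumer()` views wrap these same calls in context managers for the common case.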
--- ## RingBufferConsumer
`@register_passable(trivial)` `struct RingBufferConsumer[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_consumers: Int, cluster_size: Int, tma_transfer: Bool, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer]]]` Consumer view of the ring buffer. This struct provides the consumer interface to the ring buffer, allowing consumers to wait for and access tiles loaded by the producer. It handles the initial barrier arrival when entering the consumer context. ## Fields * ​ring\_buffer\_ptr (`RingBufferConsumer[origin, ring_buffer_type].RingBufferPtrType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = ring_buffer_type.ATile` ### `BTile` `comptime BTile = ring_buffer_type.BTile` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `get_tiles` `get_tiles(mut self) -> ConsumerTiles[origin, ring_buffer_type]` Get a context manager for accessing the next available tile. **Returns:** [`ConsumerTiles`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/ring_buffer/ConsumerTiles)
--- ## RingBufferProducer
`@register_passable(trivial)` `struct RingBufferProducer[a_type: DType, b_type: DType, a_tile_layout: Layout, b_tile_layout: Layout, num_pipeline_stages: Int, num_consumers: Int, cluster_size: Int, tma_transfer: Bool, //, origin: MutOrigin, ring_buffer_type: AnyStruct[RingBuffer[a_type, b_type, a_tile_layout, b_tile_layout, num_pipeline_stages, num_consumers, cluster_size, tma_transfer]]]` Producer view of the ring buffer. This struct provides the producer interface to the ring buffer, allowing the producer to wait for empty slots and fill them with new tiles. ## Fields * ​ring\_buffer\_ptr (`RingBufferProducer[origin, ring_buffer_type].RingBufferPtrType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ATile` `comptime ATile = ring_buffer_type.ATile` ### `BTile` `comptime BTile = ring_buffer_type.BTile` ### `RingBufferPtrType` `comptime RingBufferPtrType = Pointer[ring_buffer_type, origin]` ## Methods ### `__init__` `__init__(ring_buffer_ptr: Pointer[ring_buffer_type, origin]) -> Self` ### `__enter__` `__enter__(mut self) -> Self` ### `get_tiles` `get_tiles(mut self) -> ProducerTiles[origin, ring_buffer_type]` Get a context manager for accessing the next empty tile slot. **Returns:** [`ProducerTiles`](/mojo/kernels/linalg/matmul/gpu/sm100_structured/ring_buffer/ProducerTiles)
--- ## ring_buffer
Ring buffer implementation for producer-consumer synchronization in GPU kernels. This module provides a ring buffer abstraction that enables efficient overlap of memory transfers and computation in matrix multiplication kernels. The pattern divides work between: * Producer: One warp group that loads tiles from global to shared memory * Consumers: Multiple warp groups that process tiles using tensor cores The ring buffer uses barrier synchronization to coordinate access to a circular queue of tile buffers, allowing the producer to work ahead while consumers process previously loaded data. Usage Example: ``` # Create ring buffer during kernel initialization var ring_buffer = RingBuffer[...](full_mbar, empty_mbar, ...) # Producer pattern with ring_buffer.producer() as producer: while has_work(): with producer.get_tiles() as tiles: # Load data into tiles.a_tile and tiles.b_tile load_tile(tiles.a_tile, tiles.barrier) # Consumer pattern with ring_buffer.consumer() as consumer: while has_work(): with consumer.get_tiles() as tiles: # Process tiles.a_tile and tiles.b_tile gemm(tiles.a_tile, tiles.b_tile, output) ``` ## Structs * [​`ConsumerTiles`](./ConsumerTiles): Context manager for consumer access to ring buffer tiles. * [​`ProducerTiles`](./ProducerTiles): Context manager for producer access to ring buffer tiles. * [​`RingBuffer`](./RingBuffer): Ring buffer for managing pipeline synchronization between producers and consumers. * [​`RingBufferConsumer`](./RingBufferConsumer): Consumer view of the ring buffer. * [​`RingBufferProducer`](./RingBufferProducer): Producer view of the ring buffer.
--- ## testbed
## Functions * [​`test_matmul_sm90`](./test_matmul_sm90):
--- ## test_matmul_sm90
`test_matmul_sm90[a_type: DType, b_type: DType, c_type: DType, cluster_shape: IndexList[3], block_tile_shape: IndexList[3], wgmma_shape: IndexList[3], num_consumer: Int = 1, num_pipeline_stages: Int = 4, transpose_b: Bool = True, partitioned_multicast: Bool = False, grid_shape: OptionalReg[IndexList[2]] = None, use_tma_store: Bool = False, schedule: MatmulSchedule = MatmulSchedule.NONE, default_epilogue: Bool = False, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, measure_threshold: OptionalReg[Float64] = None, backend: Backend = Backend.CUBLAS](ctx: DeviceContext, m: ValOrDim[dim], n: ValOrDim[dim], k: ValOrDim[dim])`
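This harness runs and verifies a single SM90 matmul configuration. A hedged invocation sketch, assuming the `dynamic`/`static` `ValOrDim` helpers and `Index` constructor from the kernels utility modules (names not confirmed by this document):

```mojo
with DeviceContext() as ctx:
    # bf16 inputs and output; dynamic M with static N/K dimensions.
    test_matmul_sm90[
        DType.bfloat16, DType.bfloat16, DType.bfloat16,
        cluster_shape = Index(1, 1, 1),
        block_tile_shape = Index(128, 128, 64),
        wgmma_shape = Index(64, 128, 16),
        num_consumer=2,
    ](ctx, dynamic(1024), static[4096](), static[512]())
```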
--- ## TileLoader
Base trait for tile loading mechanisms in matrix multiplication. This trait defines the interface for loading tiles from global memory to shared memory, abstracting over different hardware mechanisms. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `load_tile` `load_tile(self: _Self, dst: LayoutTensor[_Self._dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], coords: Tuple[UInt, UInt])` Load a tile from global memory to shared memory. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tile in shared memory (must be 128-byte aligned). * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Memory barrier for synchronization. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile coordinates (row, column) in the source matrix.
--- ## TileLoaderCPAsync
`@register_passable(trivial)` `struct TileLoaderCPAsync[dtype: DType, src_layout: Layout, thread_layout: Layout, swizzle_mode: TensorMapSwizzle, vector_size: Int]` Software-based tile loader using cp.async instructions. This loader uses CUDA's cp.async instructions for asynchronous memory transfers with manual bounds checking and shared memory swizzling for optimal bank conflict avoidance. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the elements being loaded. * ​src\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the source matrix in global memory. * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread arrangement for distributed copying. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzling pattern for shared memory access. * ​vector\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements loaded per thread. ## Fields * ​src (`LayoutTensor[dtype, src_layout, MutAnyOrigin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`TileLoader`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_loader/TileLoader), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(src: LayoutTensor[dtype, src_layout, MutAnyOrigin]) -> Self` Initialize the cp.async tile loader. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tensor in global memory. 
### `load_tile` `load_tile(self, dst: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], coords: Tuple[UInt, UInt])` Load a tile using cp.async instructions. Extracts a tile from the source tensor and performs an asynchronous copy to shared memory with bounds checking and swizzling. Note: Unlike TMA, this method expects tile indices and handles the conversion to element offsets internally via the tile() method. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tile in shared memory. * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Memory barrier for synchronization (currently unused). * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile indices (row\_tile, col\_tile) in the source matrix.
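A sketch of constructing this loader and issuing one load from a producer thread. The source layout, swizzle choice, destination tile, and barrier pointer are illustrative assumptions, not values prescribed by this API:

```mojo
comptime thread_layout = Layout.row_major(8, 16)  # 128-thread warp group

var loader = TileLoaderCPAsync[
    DType.bfloat16, src_layout, thread_layout,
    TensorMapSwizzle.SWIZZLE_128B, vector_size=8,
](a_gmem)

# Coordinates are tile indices; the loader converts them to element
# offsets internally via tile(). The barrier argument is unused here.
loader.load_tile(a_smem_tile, barrier_ptr, (block_m_tile, k_tile))
```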
--- ## TileLoaderTMA
`@register_passable(trivial)` `struct TileLoaderTMA[tma_origin: ImmutOrigin, dtype: DType, tile_layout: Layout, desc_layout: Layout, /, *, BK: UInt, cluster_size: Int32, use_partitioned_multicast: Bool]` TMA-based tile loader for hardware-accelerated memory transfers. This loader uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from global to shared memory, with optional multicast support for multi-block clusters. ## Parameters * ​tma\_origin ([`ImmutOrigin`](/mojo/stdlib/builtin/type_aliases/#immutorigin)): Origin type for the TMA operation. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the elements being loaded. * ​tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the complete tile in shared memory. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout described by the TMA descriptor (may be smaller). * ​BK ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Block size in the K dimension (for coordinate conversion). * ​cluster\_size ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Number of blocks in the cluster (1 for no clustering). * ​use\_partitioned\_multicast ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to use partitioned multicast loading. 
## Fields * ​tma\_op (`TileLoaderTMA[tma_origin, dtype, tile_layout, desc_layout, BK=BK, cluster_size=cluster_size, use_partitioned_multicast=use_partitioned_multicast].TMATensorTilePtr`): * ​rank (`UInt`): * ​multicast\_mask (`UInt16`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`TileLoader`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_loader/TileLoader), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TMATensorTilePtr` `comptime TMATensorTilePtr = Pointer[TMATensorTile[dtype, tile_layout, desc_layout], tma_origin]` ## Methods ### `__init__` `__init__(tma_op: Pointer[TMATensorTile[dtype, tile_layout, desc_layout], tma_origin], rank: UInt, multicast_mask: UInt16) -> Self` Initialize the TMA tile loader. **Args:** * ​tma\_op ([`Pointer`](/mojo/stdlib/memory/pointer/Pointer)): Pointer to the TMA tensor descriptor. * ​rank ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Rank of this block within the cluster. * ​multicast\_mask ([`UInt16`](/mojo/stdlib/builtin/simd/#uint16)): Bit mask for multicast targets. ### `load_tile` `load_tile(self, dst: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], mem_barrier: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], _coords: Tuple[UInt, UInt])` Load a tile using TMA hardware acceleration. Converts tile indices to element coordinates and initiates a TMA transfer. 
For clusters, uses multicast to share data across blocks. Note: Coordinates are converted from (row, col) tile indices to (k\_elements, row/col\_elements) for TMA's K-major ordering. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tile in shared memory. * ​mem\_barrier ([`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)): Memory barrier for synchronization. * ​\_coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile coordinates (row\_tile\_idx, col\_tile\_idx).
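In contrast to the cp.async loader, construction here wraps an existing TMA descriptor plus cluster metadata. A hedged sketch, where the descriptor `a_tma_op`, the rank value, and the mask are illustrative assumptions:

```mojo
var loader = TileLoaderTMA[
    BK=64, cluster_size=2, use_partitioned_multicast=False,
](Pointer(to=a_tma_op), rank=cluster_rank, multicast_mask=0b11)

# Same TileLoader interface: tile indices in, TMA transfer out.
# The hardware signals the barrier when the transfer completes.
loader.load_tile(a_smem_tile, barrier_ptr, (block_m_tile, k_tile))
```

Because both loaders implement the `TileLoader` trait, the producer loop can be written once and parameterized over the loading mechanism.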
--- ## async_copy_with_bound_check
`async_copy_with_bound_check[dtype: DType, src_layout: Layout, dst_layout: Layout, //, thread_layout: Layout, swizzle_mode: TensorMapSwizzle](src: LayoutTensor[dtype, src_layout, MutAnyOrigin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Helper function for cp.async with boundary checking. This function performs element-wise async copies with per-element boundary checking. Out-of-bounds accesses are automatically zero-filled, ensuring safe operation near matrix edges. It also handles shared memory swizzling to avoid bank conflicts and maximize memory bandwidth utilization. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the elements. * ​src\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the source tile. * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the destination tile. * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread arrangement for distributed copying. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Swizzling pattern for bank conflict avoidance. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tensor fragment in global memory. * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor fragment in shared memory.
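A minimal call sketch; the two fragment tensors come from distributing a tile across threads and are illustrative placeholders here:

```mojo
# Copy one distributed fragment; lanes that fall outside the source
# bounds write zeros into shared memory instead of faulting.
async_copy_with_bound_check[
    thread_layout = Layout.row_major(8, 16),
    swizzle_mode = TensorMapSwizzle.SWIZZLE_128B,
](src_fragment, dst_smem_fragment)
```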
--- ## tile_loader
TileLoader module for efficient tile loading in GPU matrix multiplication. This module provides utilities for loading matrix tiles from global memory to shared memory using two different mechanisms: 1. TMA (Tensor Memory Accelerator): Hardware-accelerated loads that can efficiently transfer 2D tiles with multicast support for multi-block clusters. 2. cp.async: Software-based asynchronous copy instructions with manual bounds checking and swizzling for optimal shared memory access patterns. The TileLoader trait abstracts these loading mechanisms to provide a unified interface for the matmul kernel's producer threads. ## Structs * [​`TileLoaderCPAsync`](./TileLoaderCPAsync): Software-based tile loader using cp.async instructions. * [​`TileLoaderTMA`](./TileLoaderTMA): TMA-based tile loader for hardware-accelerated memory transfers. ## Traits * [​`TileLoader`](./TileLoader): Base trait for tile loading mechanisms in matrix multiplication. ## Functions * [​`async_copy_with_bound_check`](./async_copy_with_bound_check): Helper function for cp.async with boundary checking.
--- ## FragmentToSMemWriter
`@register_passable(trivial)` `struct FragmentToSMemWriter[c_type: DType, c_tile_layout: Layout, //, tile_n_size: Int, num_m_mmas: Int, num_consumer: Int, half_tile: Bool, WG_BM: Int, WG_BN: Int, sub_wg_id: Int]` Writes WGMMA accumulator results from registers to shared memory using st.matrix. Stores 16-byte fragments with swizzling to avoid bank conflicts. Sub-warp groups divide N-dimension work, each handling a portion of WG\_BN output tiles. ## Parameters * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Output data type (must be bfloat16 for st.matrix). * ​c\_tile\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the entire shared memory region. * ​tile\_n\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Width of each output tile (typically TMA\_BN). * ​num\_m\_mmas ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of MMA operations in M dimension. * ​num\_consumer ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of consumer warp groups. * ​half\_tile ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Special mode for handling partial tiles. * ​WG\_BM ([`Int`](/mojo/stdlib/builtin/int/Int)): Warp group tile height. * ​WG\_BN ([`Int`](/mojo/stdlib/builtin/int/Int)): Warp group tile width. * ​sub\_wg\_id ([`Int`](/mojo/stdlib/builtin/int/Int)): Which portion of WG\_BN this instance handles. 
## Fields * ​c\_tile (`LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]`): * ​warp\_group\_thread\_idx (`UInt`): * ​local\_warp\_group\_idx (`UInt`): * ​st\_matrix\_rt\_layout (`RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`RegTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/RegTileWriter), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `st_matrix_layout` `comptime st_matrix_layout = Layout.row_major(WG_BM, tile_n_size)` ### `st_matrix_rt_layout_type` `comptime st_matrix_rt_layout_type = RuntimeLayout[st_matrix_n_layout[c_type, tile_n_size, num_m_mmas, num_consumer](), element_type=DType.int32, linear_idx_type=DType.int32]` ### `st_matrix_swizzle` `comptime st_matrix_swizzle = make_ldmatrix_swizzle[c_type, tile_n_size, log2_floor((16 // size_of[c_type]()))]()` ## Methods ### `__init__` `__init__(c_tile: LayoutTensor[c_type, c_tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128], warp_group_thread_idx: UInt, local_warp_group_idx: UInt) -> Self` Initialize the fragment writer. **Args:** * ​c\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Shared memory tile to write to. * ​warp\_group\_thread\_idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Thread index within the warp group. 
* ​local\_warp\_group\_idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Sub-warp group index (divides N-dimension work). ### `write_tile` `write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Write accumulator tile from registers to shared memory. **Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Register tile containing MMA results. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile position (row\_idx, col\_idx) in output.
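A sketch of wiring this writer into a consumer epilogue; the tile shapes, indices, and `c_smem_tile` are illustrative assumptions:

```mojo
var writer = FragmentToSMemWriter[
    tile_n_size=64, num_m_mmas=2, num_consumer=2,
    half_tile=False, WG_BM=128, WG_BN=128, sub_wg_id=0,
](c_smem_tile, warp_group_thread_idx, local_warp_group_idx)

# After the WGMMA loop: move accumulators to shared memory via st.matrix,
# swizzled so a subsequent TMA store reads without bank conflicts.
writer.write_tile(c_reg_tile, (row_idx, col_idx))
```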
--- ## RegTileWriter
Base trait for tile writing mechanisms in matrix multiplication. This trait defines the interface for writing register tiles to memory (either shared memory or global memory). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `write_tile` `write_tile(self: _Self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Write a register tile to memory. **Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source register tile containing accumulator values. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile coordinates (row, column) in the destination matrix.
--- ## RegisterToGMemWriter
`@register_passable(trivial)` `struct RegisterToGMemWriter[c_type: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, wgmma_shape: IndexList[3], num_consumer: Int, N: Int, epilogue_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, check_runtime_bounds: Bool = False]` Writer for transferring accumulator registers directly to global memory. This writer handles the direct copy from register tiles to global memory tiles, with proper thread distribution and alignment. It supports optional epilogue processing, compute lambda transformations, and bounds checking. Note: At most one of epilogue\_fn or compute\_lambda\_fn should be set. ## Parameters * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Output data type. * ​dst\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the destination tensor. * ​dst\_address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): Address space of the destination tensor. * ​dst\_element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Element layout of the destination tensor. * ​dst\_layout\_int\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Integer type for destination layout indices. * ​dst\_linear\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Linear index type for destination tensor. * ​dst\_masked ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the destination tensor is masked. * ​dst\_alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Alignment requirement for destination tensor. * ​wgmma\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Shape of the WGMMA operation \[M, N, K]. 
* ​num\_consumer ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of consumer warp groups. * ​N ([`Int`](/mojo/stdlib/builtin/int/Int)): Matrix N dimension. * ​epilogue\_fn ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional epilogue function (mutates value in place). * ​compute\_lambda\_fn ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional compute lambda function (returns new value). * ​check\_runtime\_bounds ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to perform bounds checking on N dimension. ## Fields * ​thread\_info (`ThreadInfo`): * ​dst (`RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds].DstType`): * ​num\_m\_mmas (`Int`): * ​tile\_coords (`OptionalReg[TileCoordinates]`): * ​max\_row (`OptionalReg[UInt32]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`RegTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/RegTileWriter), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `c_frag_size` `comptime c_frag_size = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) * wgmma_shape.__getitem__[3, DType.int64, Int](1)) // WARPGROUP_SIZE)` ### `DstType` `comptime DstType = LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]` ### `num_frag_mats` `comptime num_frag_mats = (RegisterToGMemWriter[wgmma_shape, num_consumer, N, 
epilogue_fn, compute_lambda_fn, check_runtime_bounds].num_n_frag_mat * RegisterToGMemWriter[wgmma_shape, num_consumer, N, epilogue_fn, compute_lambda_fn, check_runtime_bounds].num_m_frag_mat)` ### `num_m_frag_mat` `comptime num_m_frag_mat = ((wgmma_shape.__getitem__[3, DType.int64, Int](0) // 4) // 8)` ### `num_n_frag_mat` `comptime num_n_frag_mat = (wgmma_shape.__getitem__[3, DType.int64, Int](1) // 8)` ## Methods ### `__init__` `__init__(dst: LayoutTensor[c_type, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], warp_group_thread_idx: UInt, num_m_mmas: Int, tile_coords: OptionalReg[TileCoordinates] = None, max_row: OptionalReg[UInt32] = None) -> Self` Initialize the register-to-global-memory writer. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in global memory. * ​warp\_group\_thread\_idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Thread index within the warp group. * ​num\_m\_mmas ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of MMA tiles in M dimension. * ​tile\_coords ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional tile coordinates for epilogue processing. * ​max\_row ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional maximum valid M coordinate (for epilogue). ### `write_tile` `write_tile(self, c_reg_tile: LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], coords: Tuple[UInt, UInt])` Write a single MMA tile from registers to global memory. **Args:** * ​c\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Register tile containing accumulator values. 
* ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile coordinates (row, column) in the destination matrix.
--- ## SMemTileWriter
Base trait for tile writing mechanisms in matrix multiplication. This trait defines the interface for writing tiles from shared memory to global memory, abstracting over different hardware mechanisms. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ## Required methods ### `write_tile` `write_tile(self: _Self, src: LayoutTensor[_Self._dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], coords: Tuple[UInt, UInt])` Write a tile from shared memory to global memory. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tile in shared memory (must be 128-byte aligned). * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile coordinates (row, column) in the destination matrix.
--- ## ThreadInfo
`@register_passable(trivial)` `struct ThreadInfo` Thread identification within the warp group. ## Fields * ​warp\_id (`UInt`): * ​lane\_id (`UInt`): * ​lane\_row (`UInt32`): * ​lane\_col (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(warp_id: UInt, lane_id: UInt, lane_row: UInt32, lane_col: UInt32) -> Self` ### `from_warp_group_idx` `static from_warp_group_idx(warp_group_thread_idx: UInt) -> Self` Create ThreadInfo from a warp group thread index. **Args:** * ​warp\_group\_thread\_idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Thread index within the warp group. **Returns:** `Self`: ThreadInfo struct with computed warp\_id, lane\_id, lane\_row, and lane\_col.
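The index decomposition performed by `from_warp_group_idx` can be sketched in Python. The 32-thread warp size is standard NVIDIA hardware; the 8x4 lane grid used here for `lane_row`/`lane_col` is an assumption based on the st.matrix fragment layout, not something this reference specifies:

```python
WARP_SIZE = 32  # threads per warp on NVIDIA hardware

def thread_info_from_warp_group_idx(warp_group_thread_idx: int):
    """Sketch of ThreadInfo.from_warp_group_idx.

    The warp_id/lane_id split is standard; the assumed 8x4 lane grid for
    lane_row/lane_col mirrors st.matrix-style fragments and is illustrative.
    """
    warp_id = warp_group_thread_idx // WARP_SIZE
    lane_id = warp_group_thread_idx % WARP_SIZE
    lane_row = lane_id // 4  # assumed: 8 rows of 4 lanes within the warp
    lane_col = lane_id % 4
    return warp_id, lane_id, lane_row, lane_col
```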
--- ## TileCoordinates
`@register_passable(trivial)` `struct TileCoordinates` Helper struct for managing tile coordinate offsets. This struct encapsulates corner and split coordinates used in epilogue processing and provides a clean interface for coordinate transformations. ## Fields * ​corner (`IndexList[2]`): * ​split (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(corner: IndexList[2], split: IndexList[2]) -> Self` Initialize tile coordinates. **Args:** * ​corner ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Corner coordinates offset. * ​split ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Split coordinates offset. ### `adjust` `adjust(self, base_coords: IndexList[2]) -> IndexList[2]` Add corner and split offsets to base coordinates. **Args:** * ​base\_coords ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Base tile coordinates. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): Adjusted coordinates with corner and split offsets applied.
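The `adjust` method is a plain elementwise offset, per the description above ("Add corner and split offsets to base coordinates"). A minimal Python sketch of that behavior for 2-element index lists:

```python
def adjust(corner, split, base_coords):
    """Add corner and split offsets to base tile coordinates,
    mirroring TileCoordinates.adjust over IndexList[2] values."""
    return [base_coords[i] + corner[i] + split[i] for i in range(2)]
```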
--- ## TileWriterTMA
`@register_passable(trivial)` `struct TileWriterTMA[tma_origin: ImmutOrigin, dtype: DType, tma_layout: Layout, desc_layout: Layout, //]` TMA-based tile writer for hardware-accelerated memory transfers. This writer uses NVIDIA's Tensor Memory Accelerator (TMA) for efficient 2D tile transfers from shared to global memory. ## Parameters * ​tma\_origin ([`ImmutOrigin`](/mojo/stdlib/builtin/type_aliases/#immutorigin)): Origin type for the TMA operation. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the elements being written. * ​tma\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the TMA tile for async store operations. * ​desc\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout described by the TMA descriptor. ## Fields * ​tma\_op (`TileWriterTMA.TMATensorTilePtr`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`SMemTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/SMemTileWriter), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `TMATensorTilePtr` `comptime TMATensorTilePtr = Pointer[TMATensorTile[dtype, tma_layout, desc_layout], tma_origin]` ## Methods ### `__init__` `__init__(tma_op: Pointer[TMATensorTile[dtype, tma_layout, desc_layout], tma_origin]) -> Self` Initialize the TMA tile writer. **Args:** * ​tma\_op ([`Pointer`](/mojo/stdlib/memory/pointer/Pointer)): Pointer to the TMA tensor descriptor. 
### `write_tile` `write_tile(self, src: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], coords: Tuple[UInt, UInt])` Write a tile using TMA hardware acceleration. Performs an asynchronous TMA store from shared memory to global memory. The operation includes proper fencing and synchronization. Note: Coordinates are expected in (N, M) order for column-major output. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tile in shared memory. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile coordinates (col, row) in element space.
--- ## TileWriterThreadwise
`@register_passable(trivial)` `struct TileWriterThreadwise[dtype: DType, dst_layout: Layout, dst_address_space: AddressSpace, dst_element_layout: Layout, dst_layout_int_type: DType, dst_linear_idx_type: DType, dst_masked: Bool, dst_alignment: Int, //, thread_layout: Layout, simd_size: Int, half_tile: Bool = False]` ## Fields * ​dst (`TileWriterThreadwise[thread_layout, simd_size, half_tile].DstType`): * ​thread\_idx (`UInt`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`SMemTileWriter`](/mojo/kernels/linalg/matmul/gpu/sm90/tile_writer/SMemTileWriter), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DstType` `comptime DstType = LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment]` ## Methods ### `__init__` `__init__(dst: LayoutTensor[dtype, dst_layout, MutAnyOrigin, address_space=dst_address_space, element_layout=dst_element_layout, layout_int_type=dst_layout_int_type, linear_idx_type=dst_linear_idx_type, masked=dst_masked, alignment=dst_alignment], thread_idx: UInt) -> Self` Initialize the threadwise tile writer. **Args:** * ​dst ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination tensor in global memory. * ​thread\_idx ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Thread index within the consumer warp group. 
### `write_tile` `write_tile(self, src: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=128], coords: Tuple[UInt, UInt])` Write a tile using thread-distributed stores. Each thread writes a portion of the tile with proper swizzling for optimal memory access patterns. **Args:** * ​src ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source tile in shared memory. * ​coords ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tile indices (row\_tile, col\_tile) in the destination matrix.
--- ## tile_writer (Tile_writer)
TileWriter module for efficient tile writing in GPU matrix multiplication. This module provides utilities for writing tiles to memory using different mechanisms and destinations: 1. Register → Shared Memory: Uses the st.matrix hardware instruction for efficient storage of WGMMA accumulator results to shared memory with swizzling. 2. Register → Global Memory: Direct stores from register tiles to global memory with optional epilogue processing and bounds checking. 3. Shared Memory → Global Memory: Hardware-accelerated TMA stores or regular stores for efficient 2D tile transfers from shared to global memory. Two main traits abstract these writing mechanisms: * SMemTileWriter: For shared memory → global memory transfers * RegTileWriter: For register → memory (shared or global) transfers ## Structs * [​`FragmentToSMemWriter`](./FragmentToSMemWriter): Writes WGMMA accumulator results from registers to shared memory using st.matrix. * [​`RegisterToGMemWriter`](./RegisterToGMemWriter): Writer for transferring accumulator registers directly to global memory. * [​`ThreadInfo`](./ThreadInfo): Thread identification within the warp group. * [​`TileCoordinates`](./TileCoordinates): Helper struct for managing tile coordinate offsets. * [​`TileWriterThreadwise`](./TileWriterThreadwise): * [​`TileWriterTMA`](./TileWriterTMA): TMA-based tile writer for hardware-accelerated memory transfers. ## Traits * [​`RegTileWriter`](./RegTileWriter): Base trait for tile writing mechanisms in matrix multiplication. * [​`SMemTileWriter`](./SMemTileWriter): Base trait for tile writing mechanisms in matrix multiplication.
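The trait split described above can be sketched as a structural interface plus a toy implementer. This is illustrative Python, not the Mojo trait definitions; the nested copy loop stands in for swizzled, thread-distributed stores:

```python
from typing import Protocol, Tuple, runtime_checkable

@runtime_checkable
class SMemTileWriter(Protocol):
    """Shared memory -> global memory transfers (e.g. TMA or threadwise stores)."""
    def write_tile(self, src, coords: Tuple[int, int]) -> None: ...

class ToyThreadwiseWriter:
    """Toy implementer: copies a 2D tile into dst at the given element offset."""
    def __init__(self, dst):
        self.dst = dst

    def write_tile(self, src, coords):
        row, col = coords
        for i, src_row in enumerate(src):
            for j, value in enumerate(src_row):
                self.dst[row + i][col + j] = value
```

Any type exposing a matching `write_tile` satisfies the interface structurally, which is the role the Mojo traits play for the TMA and threadwise writers below.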
--- ## TuningConfigSM90
`@register_passable(trivial)` `struct TuningConfigSM90` ## Fields * ​M (`Int`): * ​N (`Int`): * ​K (`Int`): * ​mma\_shape (`IndexList[3]`): * ​block\_tile\_shape (`IndexList[3]`): * ​num\_pipeline\_stages (`UInt`): * ​cluster\_shape (`IndexList[3]`): * ​num\_consumer (`UInt`): * ​partitioned\_multicast (`Bool`): * ​grid\_shape (`OptionalReg[IndexList[2]]`): * ​schedule (`MatmulSchedule`): * ​splits (`OptionalReg[Int]`): * ​raster\_order (`OptionalReg[RasterOrder]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`TuningConfig`](/mojo/kernels/internal_utils/dispatch_utils/TuningConfig), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(M: Int, N: Int, K: Int, mma_shape: IndexList[3], block_tile_shape: IndexList[3], num_pipeline_stages: UInt, cluster_shape: IndexList[3], num_consumer: UInt, partitioned_multicast: Bool, grid_shape: OptionalReg[IndexList[2]] = None, schedule: MatmulSchedule = MatmulSchedule.NONE, splits: OptionalReg[Int] = None, raster_order: OptionalReg[RasterOrder] = None) -> Self` ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
--- ## tuning_configs (Tuning_configs)
## Structs * [​`TuningConfigSM90`](./TuningConfigSM90):
--- ## split_k_reduce
`split_k_reduce[c_type: DType, work_space_type: DType, c_layout: Layout, work_space_layout: Layout, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, origin], work_space: LayoutTensor[work_space_type, work_space_layout, origin], ctx: DeviceContext)`
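Conceptually, a split-K matmul writes one partial product per split into the workspace and `split_k_reduce` sums them into `c`, routing each finished element through the optional elementwise lambda. A Python sketch, under the assumption that the workspace stacks M x N partials along a leading splits axis (the real kernel's workspace layout is not specified here):

```python
def split_k_reduce(c, work_space, elementwise_lambda_fn=None):
    """Reduce split-K partial results into the output matrix c.

    work_space is assumed to be a list of num_splits M x N partials.
    """
    for i in range(len(c)):
        for j in range(len(c[0])):
            acc = sum(partial[i][j] for partial in work_space)
            if elementwise_lambda_fn is not None:
                # The lambda owns the store (epilogue fusion), mirroring
                # the elementwise_lambda_fn parameter above.
                elementwise_lambda_fn((i, j), acc)
            else:
                c[i][j] = acc
    return c
```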
--- ## MatmulSchedule
`@register_passable(trivial)` `struct MatmulSchedule` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DS_SCHEDULER` `comptime DS_SCHEDULER = MatmulSchedule(3)` ### `NONE` `comptime NONE = MatmulSchedule(0)` ### `TILE1D` `comptime TILE1D = MatmulSchedule(1)` ### `TILE2D` `comptime TILE2D = MatmulSchedule(2)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## RasterOrder (Tile_scheduler)
`@register_passable(trivial)` `struct RasterOrder` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AlongM` `comptime AlongM = RasterOrder(1)` ### `AlongN` `comptime AlongN = RasterOrder(0)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)` ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)`
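The raster order decides which axis the linear block index walks fastest. A Python sketch of the mapping; the interpretation (AlongN means consecutive indices advance fastest along N, AlongM fastest along M) is an assumption, since the reference above only names the two values:

```python
def block_coords(linear_idx, num_m_blocks, num_n_blocks, raster_order):
    """Map a linear block index to (m_block, n_block) tile coordinates
    under an assumed AlongM/AlongN rasterization semantics."""
    if raster_order == "AlongN":
        # N varies fastest: walk a full row of N blocks, then step in M.
        return (linear_idx // num_n_blocks, linear_idx % num_n_blocks)
    else:  # "AlongM"
        # M varies fastest: walk a full column of M blocks, then step in N.
        return (linear_idx % num_m_blocks, linear_idx // num_m_blocks)
```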
--- ## TileScheduler (5)
`@register_passable(trivial)` `struct TileScheduler[problem_shape: IndexList[3], tile_shape: IndexList[3], grid_shape: IndexList[2], cluster: IndexList[3] = Index(1, 1, 1), raster_dim: UInt32 = 1, schedule: MatmulSchedule = MatmulSchedule.TILE2D]` ## Fields * ​idx (`UInt32`): * ​prob\_shape (`IndexList[3]`): * ​num\_waves\_m (`UInt32`): * ​num\_waves\_n (`UInt32`): * ​log\_num\_waves\_n (`FastDiv[DType.uint32]`): * ​current\_iter (`Int`): * ​num\_aligned\_m\_blocks (`UInt32`): * ​num\_blocks (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `kNum1DBlocksPerGroup` `comptime kNum1DBlocksPerGroup = 16` ### `kNumNBlocks` `comptime kNumNBlocks = SIMD[DType.uint32, 1](ceildiv(problem_shape.__getitem__[3, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](1)))` ### `num_grids` `comptime num_grids = SIMD[DType.uint32, 1]((grid_shape.__getitem__[2, DType.int64, Int](0) * grid_shape.__getitem__[2, DType.int64, Int](1)))` ### `wave_shape` `comptime wave_shape = Index[dtype=DType.uint32]((tile_shape.__getitem__[3, DType.int64, Int](0) * grid_shape.__getitem__[2, DType.int64, Int](1)), (tile_shape.__getitem__[3, DType.int64, Int](1) * grid_shape.__getitem__[2, DType.int64, Int](0)))` ## Methods ### `__init__` `__init__(prob_shape: IndexList[3]) -> Self` ### `get_current_work_info` `get_current_work_info(mut self) -> WorkInfo` **Returns:** `WorkInfo` ### `advance` `advance(mut self)` ### `fetch_next_work` `fetch_next_work(mut self) -> WorkInfo` **Returns:** 
`WorkInfo` ### `num_output_tiles` `num_output_tiles(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `fetch_next_work_ds` `fetch_next_work_ds(mut self) -> WorkInfo` **Returns:** `WorkInfo`
--- ## WorkInfo (5)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​m (`UInt32`): * ​n (`UInt32`): * ​k\_start (`UInt32`): * ​num\_k\_tiles (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `INVALID_WORK_INFO` `comptime INVALID_WORK_INFO = WorkInfo(0, 0, 0, 0, False)` ## Methods ### `__init__` `__init__() -> Self` ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `is_final_split` `is_final_split(self, k_tiles_per_output_tile: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `get_k_start` `get_k_start(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
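`is_final_split` can be sketched as a range check on the `k_start`/`num_k_tiles` fields: a split is final when its K-tile range reaches the end of the output tile's K extent. The exact arithmetic is an assumption inferred from the field names, not stated in this reference:

```python
def is_final_split(k_start, num_k_tiles, k_tiles_per_output_tile):
    """Assumed semantics: True when this split's K range ends at (or past)
    the output tile's last K tile."""
    return k_start + num_k_tiles >= k_tiles_per_output_tile
```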
--- ## tile_scheduler (3)
## Structs * [​`MatmulSchedule`](./MatmulSchedule): * [​`RasterOrder`](./RasterOrder): * [​`TileScheduler`](./TileScheduler): * [​`WorkInfo`](./WorkInfo):
--- ## ReductionMode
`@register_passable(trivial)` `struct ReductionMode` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Deterministic` `comptime Deterministic = ReductionMode(0)` ### `Nondeterministic` `comptime Nondeterministic = ReductionMode(1)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## SplitKTileScheduler
`@register_passable(trivial)` `struct SplitKTileScheduler[problem_shape_nk: IndexList[2], tile_shape: IndexList[3], splits: UInt32, num_consumer: UInt32, num_pipeline_stages: UInt32, cluster_shape: IndexList[2], raster_order: RasterOrder, reduction_mode: ReductionMode = ReductionMode.Deterministic]` ## Fields * ​prob\_shape (`IndexList[3]`): * ​block\_id\_in\_cluster (`IndexList[2]`): * ​blocks\_per\_problem (`UInt32`): * ​current\_work\_linear\_idx (`UInt32`): * ​log\_cluster\_shape\_major (`UInt32`): * ​log\_cluster\_shape\_minor (`UInt32`): * ​cluster\_blk\_major (`UInt32`): * ​locks\_ptr (`LegacyUnsafePointer[Int32]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `k_tiles_per_output_tile` `comptime k_tiles_per_output_tile = ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2))` ### `k_tiles_per_split` `comptime k_tiles_per_split = splits.__rfloordiv__[DType.uint32, 1](SIMD[DType.uint32, 1](ceildiv(problem_shape_nk.__getitem__[2, DType.int64, Int](1), tile_shape.__getitem__[3, DType.int64, Int](2))))` ### `log_cluster_size` `comptime log_cluster_size = log2_floor((cluster_shape.__getitem__[2, DType.int64, Int](0) * cluster_shape.__getitem__[2, DType.int64, Int](1)))` ### `WorkTileType` `comptime WorkTileType[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin]` #### Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): 
## Methods ### `__init__` `__init__(prob_shape: IndexList[3], block_id_in_cluster: IndexList[2], locks_ptr: LegacyUnsafePointer[UInt8]) -> Self` ### `get_sm_num` `get_sm_num(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_problem_blocks_shape` `static get_problem_blocks_shape(problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `initial_work_tile_info` `initial_work_tile_info(mut self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `get_current_work_info` `get_current_work_info(mut self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `get_worktile_m_n_idx` `get_worktile_m_n_idx(mut self, mut work_tile_info: WorkInfo, linear_tile_id: UInt32)` ### `assign_work` `assign_work(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32)` ### `get_k_start_and_linear_tile_id` `get_k_start_and_linear_tile_id(mut self, mut work_tile_info: WorkInfo, linear_idx: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `fetch_next_work` `fetch_next_work(mut self, mut work_tile_info: WorkInfo) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/linalg/matmul/gpu/tile_scheduler/WorkInfo) ### `requires_reduction` `requires_reduction(self, work_tile_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `advance_to_next_work` `advance_to_next_work(mut self)` ### `is_last_split` `is_last_split(self, work_tile_info: WorkInfo) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `get_grid_shape` `static get_grid_shape(dyn_cluster_shape: IndexList[3], dyn_raster_order: RasterOrder = RasterOrder.AlongN) -> IndexList[3]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `get_num_tiles` `static get_num_tiles(problem_shape: IndexList[3], 
dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `get_required_locks_buffer_size_bytes` `static get_required_locks_buffer_size_bytes[accum_type: DType, dyn_num_consumer: UInt32](problem_shape: IndexList[3], dyn_tile_shape: IndexList[3], dyn_cluster_shape: IndexList[2]) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `get_linear_idx_from_m_and_n` `get_linear_idx_from_m_and_n(self, tile_m: UInt32, tile_n: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `output_tile_index` `output_tile_index(self, work_tile_info: WorkInfo) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `reduction` `reduction[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], work_tile_info: WorkInfo, num_barriers: UInt32, warp_group_local_idx: UInt32)` ### `wait_eq` `static wait_eq(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, val: UInt32)` ### `wait_lt` `static wait_lt(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, count: UInt32)` ### `arrive_set` `static arrive_set(lock_ptr: LegacyUnsafePointer[Int32], barrier_id: Int32, barrier_group_thread_idx: Int, lock_idx: UInt32, increment: UInt32)` ### `store_accumulator` `store_accumulator[accum_type: DType, c_reg_layout: Layout, workspace_layout: Layout](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)` ### `reduce_add` `reduce_add[accum_type: DType, c_reg_layout: Layout, 
workspace_layout: Layout, //, *, write_back: Bool](self, reduction_workspace: LayoutTensor[accum_type, workspace_layout, MutAnyOrigin], c_reg_tile: LayoutTensor[accum_type, c_reg_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL], reduction_tile_idx: UInt32, warp_group_local_idx: UInt32, warp_group_thread_idx: UInt32)`
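`arrive_set`, `wait_eq`, and `wait_lt` suggest a counter-based barrier over the locks buffer: producer blocks bump a per-tile counter, and the reducing block spins until the expected value arrives. A single-threaded Python sketch of just the counter arithmetic; the real primitives spin on device memory with named barriers and acquire/release ordering, which this model omits:

```python
def arrive_set(locks, lock_idx, increment):
    """Producer side: add this block's contribution to the tile's counter."""
    locks[lock_idx] += increment

def wait_eq_ready(locks, lock_idx, val):
    """Consumer side: the condition wait_eq would spin on until it holds."""
    return locks[lock_idx] == val
```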
--- ## tile_scheduler_splitk (3)
## Structs * [​`ReductionMode`](./ReductionMode): * [​`SplitKTileScheduler`](./SplitKTileScheduler):
--- ## matmul (6)
Provides the backend implementation for matmuls. ## Packages * [​`cpu`](./cpu/): Provides the CPU backend implementations for matmuls. * [​`gpu`](./gpu/): * [​`vendor`](./vendor/): Provides the vendor backend implementations for matmuls. ## Functions * [​`matmul`](./matmul):
--- ## matmul (7)
`matmul[transpose_a: Bool = False, transpose_b: Bool = False, b_packed: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, _trace_description: StringSlice[StaticConstantOrigin] = "", target: StringSlice[StaticConstantOrigin] = "cpu"](c: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: Optional[DeviceContext])` `matmul[transpose_a: Bool = False, transpose_b: Bool = False, b_packed: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, _trace_description: StringSlice[StaticConstantOrigin] = "", target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], ctx: DeviceContextPtr = DeviceContextPtr())` `matmul[transpose_a: Bool = False, transpose_b: Bool = False, b_packed: Bool = False, 
elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, elementwise_compute_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, _trace_description: StringSlice[StaticConstantOrigin] = "", target: StringSlice[StaticConstantOrigin] = "cpu"](c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], ctx: Optional[DeviceContext])`
--- ## Backend
`@register_passable(trivial)` `struct Backend` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AUTOMATIC` `comptime AUTOMATIC = Backend(0)` ### `CUBLAS` `comptime CUBLAS = Backend(1)` ### `CUBLASLT` `comptime CUBLASLT = Backend(2)` ### `HIPBLASLT` `comptime HIPBLASLT = Backend(4)` ### `ROCBLAS` `comptime ROCBLAS = Backend(3)` ## Methods ### `__init__` `__init__(value: Int) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__is__` `__is__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__isnot__` `__isnot__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__int__` `__int__(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## Handle
`struct Handle[backend: Backend = _resolve_backend[Backend.AUTOMATIC]()]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `resolved_backend` `comptime resolved_backend = _resolve_backend[backend]()` ### `type` `comptime type = Variant[LegacyUnsafePointer[cublasContext], Handle, hipblasLtHandle_t]` ## Methods ### `__init__` `__init__(out self)` ### `__is__` `__is__(self, other: Backend) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__isnot__` `__isnot__(self, other: Backend) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__enter__` `__enter__(self) -> Self` ### `__exit__` `__exit__(mut self)`
--- ## blas
## Structs * [​`Backend`](./Backend): * [​`Handle`](./Handle): ## Functions * [​`matmul`](./matmul): Matmul using the vendor BLAS library, with a global handle.
--- ## matmul (Blas)
`matmul[use_tf32: Bool = False, *, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin]] = None, b_scales: OptionalReg[LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)` Matmul using the vendor BLAS library, with a global handle. `matmul[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, *, use_tf32: Bool = False, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, c_tensor: LayoutTensor[c_type, c_layout, origin], a_tensor: LayoutTensor[a_type, a_layout, origin], b_tensor: LayoutTensor[b_type, b_layout, origin], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin]] = None, b_scales: OptionalReg[LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)` `matmul[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, use_tf32: Bool = False, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, handle: Handle[backend], c_tensor: LayoutTensor[c_type, c_layout, origin], a_tensor: LayoutTensor[a_type, a_layout, origin], b_tensor: LayoutTensor[b_type, b_layout, origin], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, MutAnyOrigin]] = None, b_scales: OptionalReg[LayoutTensor[scales_type, b_scales_layout, MutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)` `matmul[use_tf32: Bool = False](ctx: DeviceContext, handle: Handle[backend], c: NDBuffer[dtype, 2, origin, shape], a: NDBuffer[dtype, 2, origin, shape], b: NDBuffer[dtype, 2, origin, shape], *, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0)`
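The NDBuffer overload with a global handle can be called as sketched below. This assumes standard BLAS GEMM semantics (C = alpha * op(A) @ op(B) + beta * C) implied by the `alpha`/`beta`/`transpose_*` arguments; the import path is an assumption and buffer construction is elided:

```mojo
from linalg.vendor_blas import matmul  # assumed import path

fn gemm[dtype: DType, shape: DimList](
    ctx: DeviceContext,
    c: NDBuffer[dtype, 2, MutAnyOrigin, shape],
    a: NDBuffer[dtype, 2, MutAnyOrigin, shape],
    b: NDBuffer[dtype, 2, MutAnyOrigin, shape],
) raises:
    # C = 1.0 * A @ B^T + 0.0 * C, using the global vendor handle.
    matmul(ctx, c, a, b, transpose_b=True, alpha=Float32(1.0), beta=Float32(0.0))
```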
--- ## vendor (Vendor)
Provides the Vendor backend implementations for matmuls. This backend is used for testing and evaluation. ## Modules * [​`blas`](./blas/): * [​`matmul`](./matmul/):
--- ## matmul (8)
## Functions * [​`matmul`](./matmul): This implements the matmul kernel for the Blackwell architecture. Because there are currently no pure Mojo kernels targeting Blackwell, this implementation calls the cuBLAS library instead.
--- ## matmul (9)
`matmul[c_type: DType, a_type: DType, b_type: DType, //, transpose_b: Bool = False, elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None, config: OptionalReg[MatmulConfig[a_type, b_type, c_type, transpose_b]] = None](c: NDBuffer[c_type, 2, origin, shape], a: NDBuffer[a_type, 2, origin, shape], b: NDBuffer[b_type, 2, origin, shape], ctx: DeviceContext)` This implements the matmul kernel for the Blackwell architecture. Because there are currently no pure Mojo kernels targeting Blackwell, this implementation calls the cuBLAS library instead.
--- ## matrix_band_part
The module implements matrix band part functions. ## Functions * [​`matrix_band_part`](./matrix_band_part):
--- ## matrix_band_part (Matrix_band_part)
`matrix_band_part[dtype: DType, int_type: DType, cond_type: DType, rank: Int, input_0_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], simd_width: Int, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu"](input_shape: IndexList[rank], num_lower: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_upper: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], exclude: LayoutTensor[cond_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)`
--- ## BTileGenerator
`struct BTileGenerator[mut: Bool, //, config: KernelConfig, a_type: DType, b_type: DType, c_type: DType, shape: DimList, transpose_b: Bool, b_packed: Bool, origin: Origin[mut]]` Struct to encapsulate a tile of B that supports prepacking. If `b_packed` is true, calls to `get_tile` will return a buffer view from B. Otherwise, calls to `get_tile` will copy a tile from B into a stack-allocated scratch buffer and return a view of that. ## Fields * ​b (`NDBuffer[b_type, 2, origin, shape]`): * ​b\_tile\_stack\_ptr (`LegacyUnsafePointer[Scalar[b_type]]`): * ​tile\_n\_k (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `get` `static get(b: NDBuffer[b_type, 2, origin, shape], tile_n_k: IndexList[2]) -> Self` ### `get_tile` `get_tile[inner_size: Int](self, global_offset: GemmShape, tile_dim_nk: IndexList[2], valid_data_dim_nk: IndexList[2]) -> NDBuffer[b_type, 3, MutAnyOrigin, config.packed_shape]` Get a packed matrix (B) tile. valid\_data\_dim\_nk is ignored for pre-packing, where the tile is padded to the shape of tile\_dim\_nk. **Args:** * ​global\_offset ([`GemmShape`](/mojo/kernels/linalg/utils/GemmShape)): Offset in the global M, N, K dimensions. * ​tile\_dim\_nk ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Tile shape based on cache size and matrix dimensions. * ​valid\_data\_dim\_nk ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The upper bounds for N and K dimensions.
**Returns:** [`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer): A view of the packed tile.
--- ## PackMatrixCols
`struct PackMatrixCols[original_mut: Bool, //, original_shape: DimList, packed_shape: DimList, dtype: DType, simd_size: Int, column_inner_size: Int, use_vnni: Bool, use_i8mm: Bool, packed_origin: MutOrigin, original_origin: Origin[original_mut]]` Pack columns from a matrix into the mlas packed layout and extract inner vectors of columns into the packed inner dimension, e.g. extracts \[X, Y] and packs as \[Yo]\[X]\[Yi]. ## Fields * ​packed\_matrix (`NDBuffer[dtype, 3, packed_origin, packed_shape]`): * ​original\_matrix (`NDBuffer[dtype, 2, original_origin, original_shape]`): * ​global\_offset (`IndexList[2]`): * ​pack\_tile\_dim (`IndexList[2]`): * ​valid\_data\_dim (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `run` `static run(packed_matrix: NDBuffer[dtype, 3, MutAnyOrigin, packed_shape], original_matrix: NDBuffer[dtype, 2, MutAnyOrigin, original_shape], global_offset: IndexList[2], pack_tile_dim: IndexList[2], valid_data_dim: IndexList[2])` Interface function to run the packing routine. **Args:** * ​packed\_matrix (`NDBuffer`): Pre-allocated buffer space for packed data. * ​original\_matrix (`NDBuffer`): Data buffer containing the original matrix to pack. * ​global\_offset (`IndexList`): Offset to use when indexing the original matrix. * ​pack\_tile\_dim (`IndexList`): 2D dimension tuple describing the size of the packed tile. * ​valid\_data\_dim (`IndexList`): 2D dimension tuple describing the amount of valid data on the global buffer, starting from the offset.
--- ## PackMatrixRows
`struct PackMatrixRows[original_mut: Bool, //, original_shape: DimList, packed_shape: DimList, dtype: DType, simd_size: Int, row_inner_size: Int, packed_origin: MutOrigin, original_origin: Origin[original_mut]]` Pack rows from a matrix into the mlas packed layout and extract inner vectors of rows into the packed inner dimension, e.g. extract tile \[X, Y] and pack into \[Xo]\[Y]\[Xi]. ## Fields * ​packed\_matrix (`NDBuffer[dtype, 3, packed_origin, packed_shape]`): * ​original\_matrix (`NDBuffer[dtype, 2, original_origin, original_shape]`): * ​global\_offset (`IndexList[2]`): * ​pack\_tile\_dim (`IndexList[2]`): * ​valid\_data\_dim (`IndexList[2]`): * ​valid\_simd\_dim (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `run` `static run(packed_matrix: NDBuffer[dtype, 3, packed_origin, packed_shape], original_matrix: NDBuffer[dtype, 2, original_origin, original_shape], global_offset: IndexList[2], pack_tile_dim: IndexList[2], valid_data_dim: IndexList[2])` Interface function to run the packing routine. **Args:** * ​packed\_matrix (`NDBuffer`): Pre-allocated buffer space for packed data. * ​original\_matrix (`NDBuffer`): Data buffer containing the original matrix to pack. * ​global\_offset (`IndexList`): Offset to use when indexing the original matrix. * ​pack\_tile\_dim (`IndexList`): 2D dimension tuple describing the size of the packed tile. * ​valid\_data\_dim (`IndexList`): 2D dimension tuple describing the amount of valid data on the global buffer, starting from the offset.
--- ## packing
## Structs * [​`BTileGenerator`](./BTileGenerator): Struct to encapsulate a tile of B that supports prepacking. * [​`PackMatrixCols`](./PackMatrixCols): Pack columns from a matrix into the mlas packed layout and extract inner vectors of columns into the packed inner dimension, e.g. extracts \[X, Y] and packs as \[Yo]\[X]\[Yi]. * [​`PackMatrixRows`](./PackMatrixRows): Pack rows from a matrix into the mlas packed layout and extract inner vectors of rows into the packed inner dimension, e.g. extract tile \[X, Y] and pack into \[Xo]\[Y]\[Xi]. ## Functions * [​`pack_b`](./pack_b): Utility function to pack the entire B matrix, such that each \[tile\_n // inner\_size, tile\_k, inner\_size] tile of src is contiguous in dst. * [​`pack_b_ndbuffer`](./pack_b_ndbuffer): * [​`pack_matmul_b_shape_func`](./pack_matmul_b_shape_func): * [​`pack_transposed_b_ndbuffer`](./pack_transposed_b_ndbuffer):
--- ## pack_b
`pack_b[transpose_b: Bool, simd_size: Int, inner_size: Int, a_type: DType, b_type: DType, c_type: DType, src_shape: DimList, dst_shape: DimList](dst: NDBuffer[b_type, 2, origin, dst_shape], src: NDBuffer[b_type, 2, origin, src_shape], tile_n: Int, tile_k: Int)` Utility function to pack the entire B matrix, such that each \[tile\_n // inner\_size, tile\_k, inner\_size] tile of src is contiguous in dst. Tiles (not tile contents) are stored in row-major order, so tile\[i, j] is tile\_n \* tile\_k elements away from tile\[i, j+1].
--- ## pack_b_ndbuffer
`pack_b_ndbuffer[b_mut: Bool, //, a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, c_type: DType, c_shape: DimList, b_origin: Origin[b_mut], output_origin: MutOrigin](b_input: NDBuffer[b_type, 2, b_origin, b_shape], output_buffer: NDBuffer[b_type, 2, output_origin])`
--- ## pack_matmul_b_shape_func
`pack_matmul_b_shape_func[a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, c_type: DType, c_shape: DimList, transpose_in_0: Bool, single_thread_blocking_override: Bool](b_input: NDBuffer[b_type, 2, origin, b_shape]) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## pack_transposed_b_ndbuffer
`pack_transposed_b_ndbuffer[a_type: DType, a_shape: DimList, b_type: DType, b_shape: DimList, c_type: DType, c_shape: DimList](b_input: NDBuffer[b_type, 2, origin, b_shape], output_buffer: NDBuffer[b_type, 2, origin])`
--- ## apply_q
`apply_q[dtype: DType, element_layout: Layout](sigma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], A: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], X: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Applies the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` to the `X` matrix. See `qr_factorization` for more details on the construction of the Householder reflector.
--- ## form_q
`form_q[dtype: DType, element_layout: Layout](sigma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], A: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], Q: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Forms the Q factor from the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` and stores the result in `Q`.
--- ## qr_factorization
## Functions * [​`apply_q`](./apply_q): Applies the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` to the `X` matrix. * [​`form_q`](./form_q): Forms the Q factor from the implicit Q factor stored in `A` and `sigma` after calling `qr_factorization` and stores the result in `Q`. * [​`qr_factorization`](./qr_factorization): Performs QR factorization of a matrix `A` using the Householder reflector method.
--- ## qr_factorization (Qr_factorization)
`qr_factorization[dtype: DType, element_layout: Layout](sigma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], A: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Performs QR factorization of a matrix `A` using the Householder reflector method. This function computes the QR factorization of matrix `A` in-place using Householder reflections. The result is stored directly in the input matrix `A`, with scaling factors in `sigma`. The implementation follows the LAPACK algorithm for generating Householder reflectors in-place. Algorithm: The Householder reflector is defined as U = I - σww^H, where w = (x + νe₁)/ξ, σ = ξ/ν, ξ = x₀ + ν, and ν = sign(x₀)‖x‖₂. This ensures that U^H x = -νe₁ and U^H U = I. References: \[1] Lehoucq, R. B. (1996). The computation of elementary unitary matrices. ACM Transactions on Mathematical Software, 22(4), 393-400. Note: There is a typo in reference \[1] (LAPACK Working Note 72); the correct result is U^H x = -νe₁.
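As a concrete check of the Householder formulas above (a worked numeric example, not taken from the reference), consider x = (3, 4)ᵀ:

```latex
% Worked Householder example for x = (3, 4)^T:
\nu = \operatorname{sign}(x_0)\,\lVert x \rVert_2 = 5, \quad
\xi = x_0 + \nu = 8, \quad
w = (x + \nu e_1)/\xi = \bigl(1,\ \tfrac{1}{2}\bigr)^T, \quad
\sigma = \xi/\nu = \tfrac{8}{5}.
% Check U^H x = x - \sigma w (w^H x): here w^H x = 3 + 2 = 5, so
U^H x = (3,4)^T - \tfrac{8}{5}\cdot 5\,\bigl(1,\tfrac{1}{2}\bigr)^T
      = (-5,\ 0)^T = -\nu e_1.
% Unitarity: \sigma\,\lVert w \rVert_2^2 = \tfrac{8}{5}\cdot\tfrac{5}{4} = 2,
% which is exactly the condition for U = I - \sigma w w^H to satisfy U^H U = I.
```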
--- ## IteratorScatterGatherAmd
`struct IteratorScatterGatherAmd[thread_layout: Layout, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1]` Iterator-based AMD scatter-gather for DRAM-register data movement. ## Parameters * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread organization layout. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total threads (defaults to thread\_layout size). * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Thread execution scope (block or warp). * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of block dimensions. ## Fields * ​buffer (`AMDBufferResource`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], tensor_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])` Initialize with tensor and iterator. **Args:** * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Layout tensor for bounds. * ​tensor\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): Iterator for AMD buffer resource. 
### `copy` `copy(self, dst_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_gmem_tile_iter: LayoutTensorIter[dtype, layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked])` Copy DRAM to registers via iterator. **Args:** * ​dst\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination register tile. * ​src\_gmem\_tile\_iter ([`LayoutTensorIter`](/mojo/kernels/layout/layout_tensor/LayoutTensorIter)): Source memory iterator.
--- ## NVIDIASharedMemoryBasePtr
`struct NVIDIASharedMemoryBasePtr[name: StringSlice[StaticConstantOrigin] = "extern_ptr_syml", memory_alignment: Int = 8]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`SharedMemoryBasePtr`](/mojo/kernels/linalg/structuring/SharedMemoryBasePtr), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `alignment` `comptime alignment = 128` ## Methods ### `ptr` `static ptr() -> LegacyUnsafePointer[Int8, address_space=AddressSpace.SHARED]` **Returns:** `LegacyUnsafePointer`
--- ## SMemArrayType
`@register_passable(trivial)` `struct SMemArrayType[type: AnyTrivialRegType, size: Int]` Shared memory array of fixed size. ## Parameters * ​type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Element type. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements. ## Fields * ​ptr (`SMemArrayType[type, size].ptr_type`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ptr_type` `comptime ptr_type = LegacyUnsafePointer[type, address_space=AddressSpace.SHARED]` ### `storage_size` `comptime storage_size = (size * size_of[type]())` ## Methods ### `__init__` `__init__(unsafe_ptr: LegacyUnsafePointer[type, address_space=AddressSpace.SHARED]) -> Self` Initialize with shared memory pointer. **Args:** * ​unsafe\_ptr (`LegacyUnsafePointer`): Shared memory pointer. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemArrayType[type, size].ptr_type` Get a pointer to the element at index. **Args:** * ​index (`T`): Element index. **Returns:** `SMemArrayType`: Pointer to element. ### `len` `static len() -> Int` Get array length in bytes. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Total size in bytes. ### `stack_allocation` `static stack_allocation[alignment: Int = align_of[type]()]() -> Self`
--- ## SMemTileArrayType
`@register_passable(trivial)` `struct SMemTileArrayType[dtype: DType, layout: Layout, num_tiles: Int, alignment: Int]` Array of tiles in shared memory. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Tile data type. * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Tile layout configuration. * ​num\_tiles ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of tiles. * ​alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Memory alignment. ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `storage_size` `comptime storage_size = ((layout.size() * size_of[dtype]()) * num_tiles)` ### `Tile` `comptime Tile = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment]` ## Methods ### `__init__` `__init__[mut: Bool, //, origin: Origin[mut]](unsafe_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin]) -> Self` Initialize with shared memory pointer. **Args:** * ​unsafe\_ptr (`LegacyUnsafePointer`): Shared memory pointer. ### `__getitem__` `__getitem__[T: Intable](self, index: T) -> SMemTileArrayType[dtype, layout, num_tiles, alignment].Tile` Get tile at index. **Args:** * ​index (`T`): Tile index. **Returns:** `SMemTileArrayType`: Tile at index. ### `stack_allocation` `static stack_allocation() -> Self`
--- ## ScatterGatherAmd
`struct ScatterGatherAmd[thread_layout: Layout, num_threads: Int = thread_layout.size(), thread_scope: ThreadScope = ThreadScope.BLOCK, block_dim_count: Int = 1]` AMD tile-based scatter-gather for DRAM-register data movement. ## Parameters * ​thread\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Thread organization layout. * ​num\_threads ([`Int`](/mojo/stdlib/builtin/int/Int)): Total threads (defaults to thread\_layout size). * ​thread\_scope ([`ThreadScope`](/mojo/kernels/layout/layout_tensor/ThreadScope)): Thread execution scope (block or warp). * ​block\_dim\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of block dimensions. ## Fields * ​buffer (`AMDBufferResource`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Initialize with a tensor. **Args:** * ​tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Layout tensor for AMD buffer resource creation. ### `copy` `copy(self, dst_reg_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_gmem_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], offset: OptionalReg[UInt] = None)` Copy DRAM to registers. **Args:** * ​dst\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination register tile. 
* ​src\_gmem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source global memory tile. * ​offset ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional copy offset. `copy(self, dst_gmem_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src_reg_tile: LayoutTensor[dtype, layout, origin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Copy registers to DRAM. **Args:** * ​dst\_gmem\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Destination global memory tile. * ​src\_reg\_tile ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Source register tile.
--- ## SharedMemoryBasePtr
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `alignment` `comptime alignment` ## Required methods ### `ptr` `static ptr() -> LegacyUnsafePointer[Int8, address_space=AddressSpace.SHARED]` **Returns:** `LegacyUnsafePointer`
--- ## SharedMemoryManager
`struct SharedMemoryManager[SMBP: SharedMemoryBasePtr]` ## Fields * ​base\_ptr (`LegacyUnsafePointer[Int8, address_space=AddressSpace.SHARED]`): * ​offset (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `Array` `comptime Array[type: AnyTrivialRegType, size: Int] = SMemArrayType[type, size]` #### Parameters * ​type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): ### `Tile` `comptime Tile[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=SMBP.alignment]` #### Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ### `TileArray` `comptime TileArray[dtype: DType, layout: Layout, num_tiles: Int] = SMemTileArrayType[dtype, layout, num_tiles, SMBP.alignment]` #### Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): * ​num\_tiles ([`Int`](/mojo/stdlib/builtin/int/Int)): ## Methods ### `__init__` `__init__(out self)` Initialize the shared memory manager. ### `build` `build[dtype: DType, layout: Layout, //, T: AnyStruct[LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=SMBP.alignment]]](mut self) -> T` Allocate a single tile. **Returns:** `T`: Allocated tile. `build[dtype: DType, layout: Layout, num_tiles: Int, //, T: AnyStruct[SMemTileArrayType[dtype, layout, num_tiles, SMBP.alignment]]](mut self) -> T` Allocate a tile array. **Returns:** `T`: Allocated tile array. `build[type: AnyTrivialRegType, size: Int, //, T: AnyStruct[SMemArrayType[type, size]]](mut self) -> T` Allocate a regular array. **Returns:** `T`: Allocated array.
--- ## structuring
## `comptime` values ### `eval` `comptime eval[T: AnyType, //, val: T] = val` Helper alias to force evaluation of expressions at compile time. #### Parameters * ​T ([`AnyType`](/stdlib/builtin/anytype/AnyType)): * ​val (`T`): ### `NVIDIASharedMemoryManager` `comptime NVIDIASharedMemoryManager = SharedMemoryManager[NVIDIASharedMemoryBasePtr]` ### `PipelineBarrier` `comptime PipelineBarrier[num_pipeline_stages: Int] = SMemArrayType[SharedMemBarrier, num_pipeline_stages]` Type alias for shared memory pipeline barrier array. #### Parameters * ​num\_pipeline\_stages ([`Int`](/stdlib/builtin/int/Int)): ### `RegTileType` `comptime RegTileType[_dtype: DType, layout: Layout, /, *, element_layout: Layout = Layout(IntTuple(1), IntTuple(1)), layout_int_type: DType = _get_layout_type(layout, AddressSpace.LOCAL), linear_idx_type: DType = _get_index_type(layout, AddressSpace.LOCAL), masked: Bool = False, alignment: Int = align_of[_dtype]()] = LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for register (local memory) tile tensors. #### Parameters * ​\_dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): * ​element\_layout ([`Layout`](/kernels/layout/layout/Layout)): * ​layout\_int\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​linear\_idx\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​masked ([`Bool`](/stdlib/builtin/bool/Bool)): * ​alignment ([`Int`](/stdlib/builtin/int/Int)): ### `SMemBarrier` `comptime SMemBarrier = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` Type alias for shared memory barrier pointer. 
### `SMemPtr` `comptime SMemPtr[type: AnyTrivialRegType] = LegacyUnsafePointer[type, address_space=AddressSpace.SHARED]` #### Parameters * ​type (`AnyTrivialRegType`): ### `SMemTileIter` `comptime SMemTileIter[dtype: DType, layout: Layout] = LayoutTensorIter[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=128]` #### Parameters * ​dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ### `SMemTileType` `comptime SMemTileType[_dtype: DType, layout: Layout, /, *, element_layout: Layout = Layout(IntTuple(1), IntTuple(1)), layout_int_type: DType = _get_layout_type(layout, AddressSpace.SHARED), linear_idx_type: DType = _get_index_type(layout, AddressSpace.SHARED), masked: Bool = False, alignment: Int = align_of[_dtype]()] = LayoutTensor[_dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` Type alias for shared memory tile tensors. #### Parameters * ​\_dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): * ​element\_layout ([`Layout`](/kernels/layout/layout/Layout)): * ​layout\_int\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​linear\_idx\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​masked ([`Bool`](/stdlib/builtin/bool/Bool)): * ​alignment ([`Int`](/stdlib/builtin/int/Int)): ## Structs * [​`IteratorScatterGatherAmd`](./IteratorScatterGatherAmd): Iterator-based AMD scatter-gather for DRAM-register data movement. * [​`NVIDIASharedMemoryBasePtr`](./NVIDIASharedMemoryBasePtr): * [​`ScatterGatherAmd`](./ScatterGatherAmd): AMD tile-based scatter-gather for DRAM-register data movement. * [​`SharedMemoryManager`](./SharedMemoryManager): * [​`SMemArrayType`](./SMemArrayType): Shared memory array of fixed size. * [​`SMemTileArrayType`](./SMemTileArrayType): Array of tiles in shared memory. 
## Traits * [​`SharedMemoryBasePtr`](./SharedMemoryBasePtr):
--- ## transpose
This module implements transpose functions. ## Functions * [​`transpose`](./transpose): Permute the axes of `input` based on `perms`, and place the result in `output`. * [​`transpose_2d`](./transpose_2d): * [​`transpose_3d_swap_inner`](./transpose_3d_swap_inner): * [​`transpose_3d_swap_outer`](./transpose_3d_swap_outer): * [​`transpose_4d_swap_middle`](./transpose_4d_swap_middle): * [​`transpose_inplace`](./transpose_inplace): * [​`transpose_strided`](./transpose_strided): * [​`transpose_trivial_memcpy`](./transpose_trivial_memcpy):
--- ## transpose (Transpose)
`transpose[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape], perms: LegacyUnsafePointer[Scalar[DType.index]])` Permute the axes of `input` based on `perms`, and place the result in `output`. Example: ```mojo transpose(output, input, [2, 0, 1]) # guarantees output[x, y, z] = input[z, x, y] ``` **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the input and output buffers. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of buffer elements. **Args:** * ​output ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The output buffer. * ​input ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): The input buffer. * ​perms (`LegacyUnsafePointer`): Permutation of the input axes.
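The guarantee above (for perms `[2, 0, 1]`, `output[x, y, z] = input[z, x, y]`) can be sketched in plain Python. The dict-of-index-tuples representation and the helper name are illustrative only, not part of the Mojo API:

```python
from itertools import product

def transpose(input_arr, in_shape, perms):
    """Sketch of the documented semantics:
    output[idx] == input[idx[perms[0]], idx[perms[1]], ...].
    `input_arr` maps index tuples to values; returns (output dict, output shape)."""
    rank = len(in_shape)
    # out_shape[perms[j]] = in_shape[j], so the gathered index stays in range.
    out_shape = [0] * rank
    for j in range(rank):
        out_shape[perms[j]] = in_shape[j]
    out = {}
    for idx in product(*(range(d) for d in out_shape)):
        out[idx] = input_arr[tuple(idx[p] for p in perms)]
    return out, tuple(out_shape)
```

With `perms = [2, 0, 1]` this reproduces the documented guarantee: the output at `(x, y, z)` reads the input at `(z, x, y)`.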
--- ## transpose_2d
`transpose_2d[rank: Int, output_shape: DimList, input_shape: DimList, dtype: DType](output: NDBuffer[dtype, rank, origin, output_shape], input: NDBuffer[dtype, rank, origin, input_shape], perms: LegacyUnsafePointer[Scalar[DType.index]], simplified_input_shape: IndexList[rank], simplified_rank: Int, offset: Int)`
--- ## transpose_3d_swap_inner
`transpose_3d_swap_inner[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape], perms: LegacyUnsafePointer[Scalar[DType.index]], simplified_input_shape: IndexList[rank], simplified_rank: Int)`
--- ## transpose_3d_swap_outer
`transpose_3d_swap_outer[rank: Int, output_shape: DimList, input_shape: DimList, dtype: DType](output: NDBuffer[dtype, rank, origin, output_shape], input: NDBuffer[dtype, rank, origin, input_shape], perms: LegacyUnsafePointer[Scalar[DType.index]], simplified_input_shape: IndexList[rank], simplified_rank: Int)`
--- ## transpose_4d_swap_middle
`transpose_4d_swap_middle[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape, strides], perms: LegacyUnsafePointer[Scalar[DType.index]], simplified_input_shape: IndexList[rank], simplified_rank: Int)`
--- ## transpose_inplace
`transpose_inplace[rows: Int, cols: Int, dtype: DType](buf: NDBuffer[dtype, 2, origin, DimList.__init__[Int, Int](rows, cols)])` `transpose_inplace[rows: Int, cols: Int, dtype: DType](buf: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## transpose_strided
`transpose_strided[rank: Int, dtype: DType, //](output: NDBuffer[dtype, rank, origin, shape], input: NDBuffer[dtype, rank, origin, shape], perms: LegacyUnsafePointer[Scalar[DType.index]])`
--- ## transpose_trivial_memcpy
`transpose_trivial_memcpy[rank: Int, output_shape: DimList, input_shape: DimList, dtype: DType](output: NDBuffer[dtype, rank, origin, output_shape], input: NDBuffer[dtype, rank, origin, input_shape])`
--- ## GemmShape
`@register_passable(trivial)` `struct GemmShape` Helper struct to unpack GEMM dimensions and layout. ## Fields * ​M (`Int`): * ​N (`Int`): * ​K (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(index: IndexList[3]) -> Self` Constructor of a gemm shape record from an index tuple. **Args:** * ​index ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The int tuple containing the index (m, n, k). ### `__getitem__` `__getitem__(self, idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `__setitem__` `__setitem__(mut self, idx: Int, value: Int)` ### `__add__` `__add__(self, rhs: Self) -> Self` Coordinate-wise addition of two gemm shape records. **Args:** * ​rhs (`Self`): Another gemm shape record to add with. ### `__sub__` `__sub__(self, rhs: Self) -> Self` Coordinate-wise subtraction of two gemm shape records. **Args:** * ​rhs (`Self`): Another gemm shape record to subtract with. ### `get` `static get[transpose_b: Bool](c: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], a: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive], b: NDBuffer[dtype, 2, origin, shape, strides, alignment2=alignment2, address_space=address_space, exclusive=exclusive]) -> Self` Constructor of a gemm shape record from input buffers. M, N, and K are intentionally calculated using `a` and `c` ONLY. 
This is because `b` may be padded to a multiple of the tile size if it has been pre-packed. **Args:** * ​c ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): NDBuffer with allocated output space. * ​a ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): NDBuffer containing matrix operand A. * ​b ([`NDBuffer`](/mojo/kernels/buffer/buffer/NDBuffer)): NDBuffer containing matrix operand B. `static get[transpose_b: Bool, layout_c: Layout, layout_a: Layout, layout_b: Layout](c: LayoutTensor[dtype, layout_c, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout_a, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout_b, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> Self` Constructor of a gemm shape record from input buffers. M, N, and K are intentionally calculated using `a` and `c` ONLY. This is because `b` may be padded to a multiple of the tile size if it has been pre-packed. **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): LayoutTensor with allocated output space. * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): LayoutTensor containing matrix operand A. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): LayoutTensor containing matrix operand B. ### `as_index` `as_index(self) -> IndexList[3]` Utility to convert the underlying data to an index tuple, so that utilities such as elementwise add can be used. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The constructed index tuple.
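The M/N/K convention described above can be sketched in Python, assuming a row-major `M x K` operand `a` and an `M x N` output `c` (the helper name and tuple representation are hypothetical, not the Mojo API):

```python
def gemm_shape_from_buffers(c_shape, a_shape):
    """Sketch of GemmShape.get: M, N, K come from c and a only,
    because b may be padded to a tile-size multiple when pre-packed.
    Assumes a is (M, K) and c is (M, N)."""
    M, N = c_shape
    K = a_shape[1]  # b's dimensions are deliberately ignored
    return (M, N, K)
```

Deriving K from `a` rather than `b` keeps the computed problem size correct even when `b` has been padded by packing.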
--- ## InnerKernelID
`@register_passable(trivial)` `struct InnerKernelID` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DEFAULT` `comptime DEFAULT = InnerKernelID(0)` ### `I8MM` `comptime I8MM = InnerKernelID(3)` ### `NEON` `comptime NEON = InnerKernelID(2)` ### `VNNI` `comptime VNNI = InnerKernelID(1)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## KernelConfig (Utils)
`struct KernelConfig` Static configuration of the matmul inner kernel. ## Fields * ​kernel\_rows (`Int`): * ​kernel\_cols (`Int`): * ​simd\_size (`Int`): * ​packed\_shape (`DimList`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, *, kernel_rows: Int, kernel_cols: Int, simd_size: Int, packed_shape: DimList)`
--- ## MicroKernelShape
`@register_passable(trivial)` `struct MicroKernelShape` Record describing the inner kernel shape. ## Fields * ​simd\_rows (`Int`): * ​simd\_cols (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(rows: Int, cols: Int) -> Self`
--- ## SubMatmulConfig
`struct SubMatmulConfig` Static configuration of sub-matrices in parallel matmul. ## Fields * ​offset (`IndexList[3]`): * ​shape (`IndexList[3]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## apply_epilogue
`apply_epilogue[elementwise_lambda: elementwise_epilogue_type, dst_layout: Layout, dst_element_layout: Layout = Layout(IntTuple(1), IntTuple(1))](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], offset: Int)`
--- ## calculate_tile_n_k
`calculate_tile_n_k[a_type: DType, b_type: DType, c_type: DType, kernel_cols: Int](n: Int, k: Int) -> IndexList[2]` Heuristic helper to decide the tile size used to partition the matmul, given the cache size and desired data layout. **Parameters:** * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the A tensor. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the B tensor. * ​c\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the C tensor. * ​kernel\_cols ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of columns of the micro kernel. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The calculated tile size to partition the matmul as (TileN, TileK). `calculate_tile_n_k[a_type: DType, b_type: DType, c_type: DType, kernel_cols: Int](global_tile_shape: GemmShape) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## dispatch_get_kernel_type
`dispatch_get_kernel_type[func: fn[x: Bool]() raises capturing -> None](m: Int, n: Int, k: Int)` `dispatch_get_kernel_type[func: fn[x: Bool]() capturing -> None](m: Int, n: Int, k: Int)`
--- ## get_kernel_config
`get_kernel_config[a_type: DType, b_type: DType, c_type: DType, *, kernel_type: Bool = False]() -> KernelConfig` Utility function to extract matmul configuration parameters for exported functions. TODO: Add target-dependent configuration parameters. **Returns:** [`KernelConfig`](/mojo/kernels/linalg/utils/KernelConfig)
--- ## get_kernel_type
`get_kernel_type(m: Int, n: Int, k: Int) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## get_matmul_arch_factor
`get_matmul_arch_factor[use_vnni: Bool, use_i8mm: Bool]() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_matmul_kernel_shape
`get_matmul_kernel_shape[a_type: DType, b_type: DType, c_type: DType, kernel_type: Bool]() -> MicroKernelShape` **Returns:** `MicroKernelShape`
--- ## get_matmul_kernel_shape_ARM
`get_matmul_kernel_shape_ARM[a_type: DType, b_type: DType, c_type: DType, kernel_type: Bool]() -> MicroKernelShape` **Returns:** `MicroKernelShape`
--- ## get_matmul_kernel_shape_x86
`get_matmul_kernel_shape_x86[kernel_type: Bool]() -> MicroKernelShape` **Returns:** `MicroKernelShape`
--- ## get_matmul_num_tasks
`get_matmul_num_tasks[a_type: DType, b_type: DType, c_type: DType, simd_size: Int, kernel_type: Bool](m: Int, n: Int, k: Int, max_num_tasks: Int) -> Int` Compute the number of tasks for parallel matmul. The max number of tasks is typically the number of threads/cores. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_matmul_prefetch_b_distance_k
`get_matmul_prefetch_b_distance_k() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_min_task_size
`get_min_task_size() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_packB_unroll_factor
`get_packB_unroll_factor() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_pack_data_size
`get_pack_data_size[dtype: DType]() -> Int` Utility to compute the number of elements to pack in each tile. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements to pack.
--- ## get_partitioned_matmul
`get_partitioned_matmul[a_type: DType, b_type: DType, c_type: DType, kernel_rows: Int, kernel_cols: Int](m: Int, n: Int, k: Int, task_id: Int, num_tasks: Int) -> SubMatmulConfig` **Returns:** `SubMatmulConfig`
--- ## get_partitioned_matmul_mojo
`get_partitioned_matmul_mojo[b_type: DType, kernel_rows: Int, kernel_cols: Int, use_i8mm: Bool = False](m: Int, n: Int, k: Int, task_id: Int, num_tasks: Int) -> SubMatmulConfig` **Returns:** `SubMatmulConfig`
--- ## get_partitioned_matmul_mojo_shape
`get_partitioned_matmul_mojo_shape[b_type: DType, kernel_rows: Int, kernel_cols: Int, use_i8mm: Bool](m: Int, n: Int, k: Int, num_tasks: Int) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## utils
## `comptime` values ### `elementwise_compute_lambda_type` `comptime elementwise_compute_lambda_type = fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> SIMD[dtype, width]` ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None` ## Structs * [​`GemmShape`](./GemmShape): Helper class to unpack gemm dimension and layout. * [​`InnerKernelID`](./InnerKernelID): * [​`KernelConfig`](./KernelConfig): Static configuration of the matmul inner kernel. * [​`MicroKernelShape`](./MicroKernelShape): Record describing the inner kernel shape. * [​`SubMatmulConfig`](./SubMatmulConfig): Static configuration of sub-matrices in parallel matmul. ## Functions * [​`apply_epilogue`](./apply_epilogue): * [​`calculate_tile_n_k`](./calculate_tile_n_k): Helper heuristic function to decide on tile size to partition the matmul given the cache size and desired data layout. * [​`dispatch_get_kernel_type`](./dispatch_get_kernel_type): * [​`get_kernel_config`](./get_kernel_config): Utility function to extract matmul configuration parameters for exported Functions. TODO: Add target dependent configuration parameters. * [​`get_kernel_type`](./get_kernel_type): * [​`get_matmul_arch_factor`](./get_matmul_arch_factor): * [​`get_matmul_kernel_shape`](./get_matmul_kernel_shape): * [​`get_matmul_kernel_shape_ARM`](./get_matmul_kernel_shape_ARM): * [​`get_matmul_kernel_shape_x86`](./get_matmul_kernel_shape_x86): * [​`get_matmul_num_tasks`](./get_matmul_num_tasks): Compute the number of tasks for parallel matmul. The max number of tasks is typically the number of threads/cores. * [​`get_matmul_prefetch_b_distance_k`](./get_matmul_prefetch_b_distance_k): * [​`get_min_task_size`](./get_min_task_size): * [​`get_pack_data_size`](./get_pack_data_size): Utility to compute the number of elements to pack in each tile. 
* [​`get_packB_unroll_factor`](./get_packB_unroll_factor): * [​`get_partitioned_matmul`](./get_partitioned_matmul): * [​`get_partitioned_matmul_mojo`](./get_partitioned_matmul_mojo): * [​`get_partitioned_matmul_mojo_shape`](./get_partitioned_matmul_mojo_shape): * [​`packA_i8mm`](./packA_i8mm): * [​`partition_work`](./partition_work): * [​`select_inner_kernel`](./select_inner_kernel): * [​`use_i8mm_fn`](./use_i8mm_fn): * [​`use_vnni_fn`](./use_vnni_fn):
--- ## packA_i8mm
`packA_i8mm[a_type: DType](t0: Int, t1: Int, k: Int, a_ptr: LegacyUnsafePointer[Scalar[a_type]], a_packed_ptr: LegacyUnsafePointer[Scalar[a_type]])`
--- ## partition_work
`partition_work(task_id: Int, num_tasks: Int, work: Int, work_block_size: Int) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## select_inner_kernel
`select_inner_kernel[a_type: DType, b_type: DType, c_type: DType]() -> InnerKernelID` **Returns:** `InnerKernelID`
--- ## use_i8mm_fn
`use_i8mm_fn[a_type: DType, b_type: DType, c_type: DType]() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## use_vnni_fn
`use_vnni_fn[a_type: DType, b_type: DType, c_type: DType]() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## MatmulConfig (Utils_gpu)
`@register_passable(trivial)` `struct MatmulConfig[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]` Static configuration of GPU matmul. ## Fields * ​block\_tile\_shape (`IndexList[3]`): * ​warp\_tile\_shape (`IndexList[3]`): * ​mma\_shape (`IndexList[3]`): * ​num\_pipeline\_stages (`UInt`): * ​num\_k\_partitions (`UInt`): * ​k\_group\_size (`UInt`): * ​num\_warp\_k\_partitions (`UInt`): * ​cluster\_shape (`IndexList[3]`): * ​num\_consumer (`UInt`): * ​partitioned\_multicast (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ACCUM_PRECISION` `comptime ACCUM_PRECISION = 1` ### `accum_type` `comptime accum_type = get_accum_type[a_type]()` ### `OUTPUT_PRECISION` `comptime OUTPUT_PRECISION = 2` ### `split_k_reduction_scheme` `comptime split_k_reduction_scheme = env_get_int["SPLITK_REDUCTION_SCHEME", 2]()` ### `split_k_reduction_type` `comptime split_k_reduction_type = c_type if (env_get_int["SPLITK_REDUCTION_SCHEME", 2]() == 2) else MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type` ## Methods ### `__init__` `__init__(*, block_tile_shape: IndexList[3] = Index(128, 128, 32), warp_tile_shape: IndexList[3] = Index(64, 64, 32), mma_shape: IndexList[3] = get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), cluster_shape: IndexList[3] = Index(1, 1, 1), num_pipeline_stages: UInt = 4, num_k_partitions: UInt = 1, 
k_group_size: UInt = 1, num_warp_k_partitions: UInt = 1, num_consumer: UInt = 1, partitioned_multicast: Bool = False, pdl_level: PDLLevel = PDLLevel()) -> Self` ### `__eq__` `__eq__(self, rhs: MatmulConfig[a_type, b_type, c_type, transpose_b]) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `copy_field` `copy_field(mut self, other: MatmulConfig[a_type, b_type, c_type, transpose_b])` ### `swapAB` `swapAB(self) -> MatmulConfig[b_type, a_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig) ### `num_warps_m` `num_warps_m(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `num_warps_n` `num_warps_n(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `num_threads` `num_threads(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `shared_mem_usage` `shared_mem_usage(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `grid_dim` `grid_dim(self, m: UInt, n: UInt) -> IndexList[3]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `block_dim` `block_dim(self) -> IndexList[3]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `work_space_size` `work_space_size(self, M: UInt, N: UInt) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `pdl_level` `pdl_level(self) -> PDLLevel` **Returns:** [`PDLLevel`](/mojo/stdlib/gpu/primitives/grid_controls/PDLLevel) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)` ### `__repr__` `__repr__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance.
--- ## MatmulKernels
`@register_passable(trivial)` `struct MatmulKernels[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False]` Supported matmul kernels. The configurations are named by architecture, block tile shape, and pipeline-stage count (for example, `ampere_128x128_4`); BK, the mma shape, and the warp tile shape are decided internally. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ampere_128x128_4` `comptime ampere_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), 4, 1, 1, 1, 1, False, PDLLevel())` ### `ampere_256x128_3` `comptime ampere_256x128_3 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 256, (2 * _bk_base[a_type]())), Index(64, 64, (2 * _bk_base[a_type]())), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), 3, 1, 1, 1, 1, False, PDLLevel())` ### `ampere_256x64_4` `comptime ampere_256x64_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(64, 256, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), 4, 1, 1, 1, 1, False, PDLLevel())` ### `hopper_128x128_4` `comptime hopper_128x128_4 = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(128, 128, _bk_base[a_type]()), Index(64, 64, _bk_base[a_type]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), 
Index(1, 1, 1), 4, 1, 1, 1, 1, False, PDLLevel())` ### `tuning_config` `comptime tuning_config = MatmulConfig[a_type, b_type, c_type, transpose_b](Index(env_get_int["TUNE_BM", 128](), env_get_int["TUNE_BN", 128](), env_get_int["TUNE_BK", 32]()), Index(env_get_int["TUNE_WM", 64](), env_get_int["TUNE_WN", 64](), env_get_int["TUNE_BK", 32]()), get_mma_shape[a_type, MatmulConfig[a_type, b_type, c_type, transpose_b].accum_type](), Index(1, 1, 1), UInt(env_get_int["TUNE_NUM_STAGES", 4]()), UInt(env_get_int["TUNE_NUM_K_PARTITIONS", 1]()), 1, UInt(env_get_int["TUNE_NUM_WARP_K_PARTITIONS", 1]()), 1, False, PDLLevel())`
--- ## block_swizzle
`block_swizzle(block_idx: IndexList[2, element_type=element_type], grid_dim: IndexList[2, element_type=element_type]) -> IndexList[2, element_type=element_type]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## create_hilbert_lut
`create_hilbert_lut(ctx: DeviceContext, grid_x: Int, grid_y: Int) -> DeviceBuffer[DType.uint32]` Precompute a Hilbert-curve block-swizzle lookup table for a rectangular grid. The returned device pointer refers to a 1-D UInt32 array of length grid\_x \* grid\_y. For a linear (row-major) block id `id`, the packed value at `lut[id]` encodes the swizzled coordinates: upper 16 bits = y, lower 16 bits = x. **Returns:** `DeviceBuffer`
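The coordinate packing described above (upper 16 bits = y, lower 16 bits = x) can be sketched in Python; the helper names are illustrative, not part of the Mojo API:

```python
def pack_swizzled(x, y):
    """Pack swizzled block coordinates into one UInt32-style value:
    upper 16 bits = y, lower 16 bits = x."""
    return (y << 16) | x

def unpack_swizzled(v):
    """Recover (x, y) from a packed LUT entry."""
    return (v & 0xFFFF, v >> 16)
```

A kernel would index the LUT with its row-major block id and unpack the entry to get its swizzled launch coordinates.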
--- ## get_hilbert_lut_with_cache
`get_hilbert_lut_with_cache(ctx: DeviceContext, grid_x: Int, grid_y: Int) -> DeviceBuffer[DType.uint32]` Get Hilbert lookup table using global cache (no struct needed). **Returns:** `DeviceBuffer`
--- ## utils_gpu
## Structs * [​`MatmulConfig`](./MatmulConfig): Static configuration of GPU matmul. * [​`MatmulKernels`](./MatmulKernels): Supported matmul kernels. ## Functions * [​`block_swizzle`](./block_swizzle): * [​`create_hilbert_lut`](./create_hilbert_lut): Precompute Hilbert-curve block swizzle lookup-table for a rectangular grid. * [​`get_hilbert_lut_with_cache`](./get_hilbert_lut_with_cache): Get Hilbert lookup table using global cache (no struct needed). * [​`select_config`](./select_config):
--- ## select_config
`select_config[a_type: DType, b_type: DType, c_type: DType, transpose_b: Bool = False](M: Int, N: Int, K: Int, ctx: DeviceContext) -> MatmulConfig[a_type, b_type, c_type, transpose_b]` **Returns:** [`MatmulConfig`](/mojo/kernels/linalg/utils_gpu/MatmulConfig)
--- ## elu
`elu[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the ELU op using the equation $z$ if $z \ge 0$, else $\alpha(e^z - 1)$. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to compute the ELU operation on. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The result of the ELU operation.
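The piecewise ELU formula can be sketched as scalar Python. Note the `alpha` parameter here is the conventional ELU constant; the Mojo signature above takes only `x`, so how alpha is supplied is not shown by these docs:

```python
import math

def elu(z, alpha=1.0):
    """ELU: z if z >= 0, else alpha * (e^z - 1)."""
    return z if z >= 0 else alpha * (math.exp(z) - 1.0)
```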
--- ## activations
This module contains implementations of activation functions. ## Functions * [​`elu`](./elu): Compute the ELU op using the equation $z$ if $z \ge 0$, else $\alpha(e^z - 1)$. * [​`leaky_relu`](./leaky_relu): Compute the Leaky ReLU using the equation $\max(0, x) + \mathrm{negative\_slope} \cdot \min(0, x)$. * [​`relu`](./relu): Compute the ReLU op using the equation $\max(0, x)$. * [​`relu_n1`](./relu_n1): Compute the ReLU N1 op using the equation $\max(\min(x, 1), -1)$. * [​`sign`](./sign): Compute the sign (0, 1) of the input value.
--- ## leaky_relu
`leaky_relu[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width], negative_slope: Scalar[dtype]) -> SIMD[dtype, simd_width]` Compute the Leaky ReLU using the equation $\max(0, x) + \mathrm{negative\_slope} \cdot \min(0, x)$. **Constraints:** Type must be a floating-point DType. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to compute the Leaky ReLU operation on. * ​negative\_slope ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The slope for negative values. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The result of the Leaky ReLU operation.
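The equation above can be sketched as scalar Python (illustrative only; the Mojo version operates on SIMD vectors):

```python
def leaky_relu(x, negative_slope):
    """Leaky ReLU: max(0, x) + negative_slope * min(0, x)."""
    return max(0.0, x) + negative_slope * min(0.0, x)
```

For positive `x` the second term is zero, so the result is `x`; for negative `x` the first term is zero, so the result is the scaled negative value.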
--- ## relu
`relu[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the ReLU op using the equation $\max(0, x)$. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to compute the ReLU operation on. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The result of the ReLU operation.
--- ## relu_n1
`relu_n1[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the ReLU N1 op using the equation $\max(\min(x, 1), -1)$. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to compute the ReLU N1 operation on. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The result of the ReLU N1 operation.
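The equation is a clamp of the input to the interval [-1, 1], which can be sketched as scalar Python (illustrative only; the Mojo version operates on SIMD vectors):

```python
def relu_n1(x):
    """ReLU N1: clamp x to [-1, 1] via max(min(x, 1), -1)."""
    return max(min(x, 1.0), -1.0)
```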
--- ## sign
`sign[dtype: DType, simd_width: Int](x: SIMD[dtype, simd_width]) -> SIMD[dtype, simd_width]` Compute the sign (0, 1) of the input value. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType used for the computation. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to compute the sign operation on. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The result of the sign operation.
--- ## arange
`arange[dtype: DType, simd_width: Int](start: Scalar[dtype], stop: Scalar[dtype], step: Scalar[dtype], index: IndexList[1]) -> SIMD[dtype, simd_width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## arange_shape
`arange_shape[dtype: DType, single_thread_blocking_override: Bool](start: Scalar[dtype], stop: Scalar[dtype], step: Scalar[dtype]) -> IndexList[1]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## arange (Arange)
## Functions * [​`arange`](./arange): * [​`arange_shape`](./arange_shape):
--- ## arg_nonzero
`arg_nonzero[dtype: DType, output_type: DType](input_buffer: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_buffer: LayoutTensor[output_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Gather the indices of all non-zero elements in the input buffer, storing them in the output buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The element dtype. * ​output\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer dtype to store the indices in. **Args:** * ​input\_buffer ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to count the non-zeros in. * ​output\_buffer ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The indices of all non-zero elements.
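The gathering semantics can be sketched in plain Python. The dict-of-index-tuples representation is illustrative, and the row-major traversal order is an assumption not stated by the docs above:

```python
from itertools import product

def arg_nonzero(values, shape):
    """Sketch: collect the index tuple of every non-zero element,
    here in row-major order, giving a [NumNonZeros, rank] result."""
    return [idx for idx in product(*(range(d) for d in shape))
            if values[idx] != 0]
```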
--- ## arg_nonzero_shape
`arg_nonzero_shape[dtype: DType, single_thread_blocking_override: Bool](input_buffer: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[2]` Return \[NumNonZeros, InputRank] where NumNonZeros is the number of non-zero elements in the input. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The element dtype. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): This op can block. **Args:** * ​input\_buffer ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to count the non-zeros in. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): Shape of the arg\_nonzero kernel for this input \[NumNonZeros, InputRank].
--- ## arg_nonzero (Arg_nonzero)
## Functions * [​`arg_nonzero`](./arg_nonzero): Gathers the indices of all non-zero elements in the input buffer, storing the indices in the output\_buffer. * [​`arg_nonzero_shape`](./arg_nonzero_shape): Returns \[NumNonZeros, InputRank], where NumNonZeros is the number of non-zero elements in the input.
--- ## argmax
`argmax(input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int, output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Finds the indices of the maximum element along the specified axis. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor. `argmax(input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis_buf: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Finds the indices of the maximum element along the specified axis. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​axis\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The axis tensor. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor.
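The axis-reduction semantics of `argmax` can be sketched in pure Python for a rank-2 input (illustrative only, not the Mojo `LayoutTensor` API; `argmin` is identical with `min` in place of `max`):

```python
def argmax(input_tensor, axis):
    """Indices of the maximum element along `axis` for a 2-D nested list."""
    if axis == 0:
        # Reduce down the rows: one index per column.
        cols = len(input_tensor[0])
        return [
            max(range(len(input_tensor)), key=lambda r: input_tensor[r][c])
            for c in range(cols)
        ]
    # axis == 1: reduce across each row, one index per row.
    return [max(range(len(row)), key=lambda j: row[j]) for row in input_tensor]

print(argmax([[1, 9, 3], [7, 2, 8]], axis=1))  # [1, 2]
print(argmax([[1, 9, 3], [7, 2, 8]], axis=0))  # [1, 0, 1]
```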
--- ## argmin
`argmin(input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int, output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Finds the indices of the minimum element along the specified axis. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor. `argmin(input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis_buf: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Finds the indices of the minimum element along the specified axis. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​axis\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The axis tensor. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor.
--- ## argmaxmin
## Functions * [​`argmax`](./argmax): Finds the indices of the maximum element along the specified axis. * [​`argmin`](./argmin): Finds the indices of the minimum element along the specified axis.
--- ## argmax_gpu
`argmax_gpu[dtype: DType, output_type: DType](ctx: DeviceContext, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## argmaxmin_gpu
`argmaxmin_gpu[dtype: DType, output_type: DType, largest: Bool](ctx: DeviceContext, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Wraps the Top-K GPU kernel with K=1 to perform argmax or argmin on the innermost dimension. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data dtype of the input tensor. * ​output\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data dtype of the output tensor. * ​largest ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether to perform argmax (True) or argmin (False).
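The "Top-K with K=1" reduction above means argmax and argmin fall out of the same kernel: select the single largest (or smallest) element and return its index. A hedged pure-Python sketch of that reduction (names are illustrative, not the Mojo API):

```python
def top_k(values, k, largest=True):
    """Return (values, indices) of the k largest (or smallest) elements."""
    order = sorted(range(len(values)), key=lambda i: values[i], reverse=largest)
    picked = order[:k]
    return [values[i] for i in picked], picked

def argmax_via_topk(values, largest=True):
    # K=1: the sole returned index is the argmax (largest=True)
    # or the argmin (largest=False).
    _, idx = top_k(values, k=1, largest=largest)
    return idx[0]

print(argmax_via_topk([3.0, 7.5, 1.2]))                 # 1
print(argmax_via_topk([3.0, 7.5, 1.2], largest=False))  # 2
```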
--- ## argmin_gpu
`argmin_gpu[dtype: DType, output_type: DType](ctx: DeviceContext, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## argmaxmin_gpu (Argmaxmin_gpu)
## Functions * [​`argmax_gpu`](./argmax_gpu): * [​`argmaxmin_gpu`](./argmaxmin_gpu): Wraps the Top-K GPU kernel with K=1 to perform argmax on the inner-most dimension. * [​`argmin_gpu`](./argmin_gpu):
--- ## argsort
`argsort[*, ascending: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)` Performs argsort on the input buffer, storing the sorted indices in the output buffer. **Parameters:** * ​ascending ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Sort direction (True for ascending, False for descending). * ​target (`StringSlice`): Target device ("cpu" or "gpu"). **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Buffer to store sorted indices. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Buffer containing values to sort. * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context for execution. `argsort[ascending: Bool = True](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` CPU-only version of argsort. **Parameters:** * ​ascending ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Sort direction (True for ascending, False for descending). **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Buffer to store sorted indices. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Buffer containing values to sort.
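Argsort returns the permutation of indices that would sort the values, rather than the values themselves. A minimal pure-Python sketch of the semantics (illustrative only, not the Mojo `LayoutTensor` API):

```python
def argsort(values, ascending=True):
    """Indices that would sort `values`; Python's sort keeps it stable."""
    return sorted(range(len(values)), key=lambda i: values[i],
                  reverse=not ascending)

vals = [30, 10, 20]
idx = argsort(vals)
print(idx)                     # [1, 2, 0]
print([vals[i] for i in idx])  # [10, 20, 30] -- indexing by idx sorts vals
```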
--- ## argsort (Argsort)
## Functions * [​`argsort`](./argsort): Performs argsort on the input buffer, storing the sorted indices in the output buffer.
--- ## Attention
`struct Attention[dtype: DType, attention_config_t: AttentionConfig, output_type: DType, q_type: DType, k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, //, config: MHAConfig[dtype], group: Int, token_gen: Bool, sink: Bool, q_depth: Int = Int(config), cache_depth: Int = Int(config), output_depth: Int = Int(config)]` ## Fields * ​out\_reg\_buffer (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].OutputRegisterBufferType`): * ​p\_reg\_buffer (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].PRegisterBufferType`): * ​rowmax (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].RowMaxTensorType`): * ​rowsum (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].RowSumTensorType`): * ​gmem\_manager (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].GlobalMemoryManagerType`): * ​smem\_manager (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].SharedMemoryManagerType`): * ​q\_buffer (`Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].QRegisterBufferType`): * ​output\_ptr (`LegacyUnsafePointer[Scalar[output_type]]`): * ​batch\_idx (`Int`): * ​k (`k_t`): * ​v (`v_t`): * ​mask (`mask_t`): * ​mask\_block\_row (`UInt32`): * ​mask\_warp\_row (`UInt32`): * ​mask\_warp\_col (`UInt32`): * ​scale (`Float32`): * ​seq\_len (`Int`): * ​num\_keys (`Int`): * ​start\_pos (`Int`): * ​cache\_start\_pos (`Int`): * ​warp\_scratch\_tensor (`LayoutTensor[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, Layout.row_major((2 * Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_warps_n)), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM)), MutAnyOrigin, address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), 
[`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = k_t.__del__is_trivial and v_t.__del__is_trivial and mask_t.__del__is_trivial` The destructor is trivial exactly when the destructors of the `k_t`, `v_t`, and `mask_t` field types are all trivial; all other fields of `Attention` have trivial destructors.
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial 
else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else 
k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else 
v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial 
else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial 
if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
`k_t.__del__is_trivial and v_t.__del__is_trivial and mask_t.__del__is_trivial`
mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial 
else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if 
k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else True if mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else mask_t.__del__is_trivial if v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial else v_t.__del__is_trivial if k_t.__del__is_trivial else k_t.__del__is_trivial` ### `accum_type` 
`comptime accum_type = get_accum_type[q_type]()` ### `BK` `comptime BK = config.block_k[dtype]()` ### `BM` `comptime BM = config.block_m[dtype]()` ### `BN` `comptime BN = config.block_n[dtype]()` ### `depth` `comptime depth = config.depth` ### `fragment_layout` `comptime fragment_layout = get_fragment_layout[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape]()` ### `fragment_layout_nested` `comptime fragment_layout_nested = get_nested_fragment_layout[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape]()` ### `GlobalMemoryManagerType` `comptime GlobalMemoryManagerType = GlobalMemoryManager[q_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].depth, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_heads, group, token_gen, q_depth, output_depth]` ### `k_group_size` `comptime k_group_size = (16 // (num_matrix_reg[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](0), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](2)]() * size_of[q_type]()))` ### `mma_shape` `comptime mma_shape = attention_config_t.get_mma_shape()()` ### `num_heads` `comptime num_heads = config.num_heads` ### `num_k_mmas2` `comptime num_k_mmas2 = ceildiv(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK, UInt((Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](2) * Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size)))` ### `num_m_mmas` 
`comptime num_m_mmas = ceildiv(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM, UInt(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](0)))` ### `num_n_mmas` `comptime num_n_mmas = ceildiv(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN, UInt(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](1)))` ### `num_n_mmas_output` `comptime num_n_mmas_output = ceildiv((output_depth // Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_warps_n)), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape.__getitem__[3, DType.int64, Int](1))` ### `num_stages` `comptime num_stages = 2` ### `num_threads` `comptime num_threads = config.num_threads[dtype]()` ### `num_warps_m` `comptime num_warps_m = (Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM // Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM)` ### `num_warps_n` `comptime num_warps_n = (Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN // Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN)` ### `output_frag_size` `comptime output_frag_size = Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].fragment_layout.size()` ### `OutputRegisterBufferType` `comptime OutputRegisterBufferType = OutputRegisterBuffer[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_m_mmas), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_n_mmas_output, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].output_frag_size]` ### `PRegisterBufferType` 
`comptime PRegisterBufferType = PRegisterBuffer[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, q_type, Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_m_mmas), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_n_mmas), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].output_frag_size, (Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN != Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size, attention_config_t.double_buffer]` ### `QRegisterBufferType` `comptime QRegisterBufferType = QRegisterBuffer[q_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size, Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WM), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].WN), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK), q_depth, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].warp_layout]` ### `row_layout` `comptime 
row_layout = Layout.row_major(Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].num_m_mmas), Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].fragment_layout.shape[0].value())` ### `RowMaxTensorType` `comptime RowMaxTensorType = LayoutTensor[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].accum_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].row_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `RowSumTensorType` `comptime RowSumTensorType = Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].RowMaxTensorType` ### `SharedMemoryManagerType` `comptime SharedMemoryManagerType = SharedMemoryManager[attention_config_t.shared_kv, attention_config_t.full_kv, attention_config_t.depth_padded, attention_config_t.double_buffer, q_type, Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BM), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BN), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK), Int(Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].depth), token_gen]` ### `swap_a_b` `comptime swap_a_b = True` ### `use_exp2` `comptime use_exp2 = True` ### `warp_layout` `comptime warp_layout = get_warp_layout[Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape]()` ### `WM` `comptime WM = config.warp_m[dtype]()` ### `WN` `comptime WN = config.warp_n[dtype]()` ## Methods ### `__init__` `__init__(out self, attention_config: attention_config_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], q: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, mask: mask_t, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]], batch_idx: Int, scale: Float32, seq_len: Int, num_keys: Int, start_pos: Int, cache_start_pos: Int = 0)` ### 
`q_head_idx` `static q_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `zero_p_buffer` `zero_p_buffer(self)` ### `get_batch_idx` `get_batch_idx(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `get_tensor_core_mma_qk` `static get_tensor_core_mma_qk(out result: TiledTensorCore[get_accum_type[q_type](), q_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size, True])` **Returns:** [`TiledTensorCore`](/mojo/kernels/layout/tensor_core/TiledTensorCore) ### `get_tensor_core_mma_pv` `static get_tensor_core_mma_pv(out result: TiledTensorCore[get_accum_type[q_type](), q_type, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].mma_shape, Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].k_group_size])` **Returns:** [`TiledTensorCore`](/mojo/kernels/layout/tensor_core/TiledTensorCore) ### `mma_qk` `mma_qk[k_buffer_type: KVBuffer, //, prefetch_function: OptionalReg[fn() capturing -> None] = None, beg_iter: Int = 0, num_iters: Int = Int((Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].depth // Attention[config, group, token_gen, sink, q_depth, cache_depth, output_depth].BK)), prefetched_b_tile: Bool = False](mut self, mut k_buffer: k_buffer_type)` ### `mma_pv` `mma_pv[v_buffer_type: KVBuffer, //, prefetch_function: OptionalReg[fn() capturing -> None] = None, prefetched_b_tile: Bool = True](mut self, mut v_buffer: v_buffer_type)` ### `mask_status` `mask_status(self, kv_tile_start_row: UInt32) -> TileMaskStatus` **Returns:** [`TileMaskStatus`](/mojo/kernels/nn/mha_mask/TileMaskStatus) ### `mask_advance` 
`mask_advance(mut self)` ### `mask_skip_tile` `mask_skip_tile(self, status: TileMaskStatus) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `mask_skip_and_advance` `mask_skip_and_advance(mut self, kv_tile_start_row: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `mask_apply` `mask_apply(mut self, kv_tile_start_row: UInt32, kv_tile_num_rows: UInt32, not_last_iter: Bool)` ### `online_softmax` `online_softmax(self)` ### `store_output` `store_output(self)` ### `copy_fragment_to_smem` `copy_fragment_to_smem(self)` ### `store_partition_info` `store_partition_info(self, num_partitions: Int, exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]])`
--- ## AttentionConfig
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `depth_padded` `comptime depth_padded` ### `double_buffer` `comptime double_buffer` ### `full_kv` `comptime full_kv` ### `shared_kv` `comptime shared_kv` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. 
**Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `q_head_idx` `static q_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `get_mma_shape` `static get_mma_shape() -> IndexList[3]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `get_q_offset` `static get_q_offset[q_depth: UInt]() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_output_offset` `static get_output_offset[output_depth: UInt]() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
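The trivial-implementation flags above are compile-time queryable, so generic code can branch on them with `@parameter if` to select a fast path. A minimal sketch (hypothetical helper, not from the source; it assumes the `__del__is_trivial` flag is accessible on a generic type parameter, as these trait docs suggest):

```mojo
# Sketch: skip per-element destruction when the element type's
# __del__ is trivially a no-op.
fn destroy_elements[T: AnyType](ptr: UnsafePointer[T], count: Int):
    @parameter
    if not T.__del__is_trivial:
        for i in range(count):
            (ptr + i).destroy_pointee()
    # otherwise: a trivial __del__ means destruction is a no-op,
    # so the loop is elided entirely at compile time
```

The same pattern applies to `__copyinit__is_trivial` and `__moveinit__is_trivial`: when the flag is set, a container can copy or move elements with a single `memcpy` instead of per-element calls.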
--- ## attention (Attention)
## Structs * [​`Attention`](./Attention): ## Traits * [​`AttentionConfig`](./AttentionConfig):
--- ## KBufferConfig
`struct KBufferConfig[BN: Int, BK: Int, WN: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`KVBufferConfig`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBufferConfig), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `btile_dim0` `comptime btile_dim0 = BN` ### `btile_dim1` `comptime btile_dim1 = BK` ### `iterator_axis` `comptime iterator_axis = 1` ### `wsize` `comptime wsize = KBufferConfig[BN, BK, WN].wtile_dim0` ### `wtile_dim0` `comptime wtile_dim0 = WN` ### `wtile_dim1` `comptime wtile_dim1 = BK` ## Methods ### `get_wtile_coord` `static get_wtile_coord() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## KVBuffer
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `mma_tile_layout` `comptime mma_tile_layout` ## Required methods ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `load_from_dram` `load_from_dram(mut self: _Self)` ### `get_mma_tile` `get_mma_tile(self: _Self) -> LayoutTensor[_Self._dtype, _Self.mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_to_shared` `copy_to_shared[tile_id: Int = 0](self: _Self)` ### `load_from_shared` `load_from_shared[k_mma: Int](self: _Self)`
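The required methods above form a staging pipeline: a tile is loaded from global memory, published to shared memory, then loaded into registers for tensor-core MMA. A hypothetical caller (a sketch, not from the source; the compile-time indices passed to `copy_to_shared` and `load_from_shared` follow the signatures shown above) might drive a conforming buffer like this:

```mojo
# Sketch only: drives a KVBuffer implementation through its trait methods.
fn consume_tile[buffer_t: KVBuffer](mut buf: buffer_t):
    buf.load_from_dram()           # stage the next tile from global memory
    buf.copy_to_shared[0]()        # publish tile 0 to shared memory
    buf.load_from_shared[0]()      # load the k_mma=0 fragments into registers
    var mma_tile = buf.get_mma_tile()  # register tile ready for MMA issue
```

In practice (see `KVBufferImpl` below) these calls are interleaved with MMA issue and barrier synchronization so that global-to-shared copies for the next tile overlap with compute on the current one.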
--- ## KVBufferConfig
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `btile_dim0` `comptime btile_dim0` ### `btile_dim1` `comptime btile_dim1` ### `iterator_axis` `comptime iterator_axis` ### `wsize` `comptime wsize` ### `wtile_dim0` `comptime wtile_dim0` ### `wtile_dim1` `comptime wtile_dim1` ## Required methods ### `get_wtile_coord` `static get_wtile_coord() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## KVBufferImpl
`struct KVBufferImpl[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, mut: Bool, dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, origin: Origin[mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, //, config: KVBufferConfig, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False]` ## Fields * ​load\_tile (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].LoadTileType`): * ​mma\_tile (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMATileType`): * ​smem\_iter (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].SharedIterType`): * ​bounds (`Int`): * ​load\_tile\_id (`Int`): * ​global\_iterator (`KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].GlobalTiledIteratorType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`KVBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBuffer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `base_layout` `comptime base_layout = Layout.row_major(config.btile_dim0, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)` ### `GlobalTensorType` `comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` ### `GlobalTiledIteratorType` `comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, 
origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, config.btile_dim0, config.btile_dim1]()[0], origin, address_space=address_space, axis=config.iterator_axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, config.btile_dim0, config.btile_dim1]()]` ### `LoadTileType` `comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(((num_stages * KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_mmas) * KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_k_tiles), KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `MMA_K` `comptime MMA_K = shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_N` `comptime MMA_N = shape.__getitem__[3, DType.int64, Int](1)` ### `mma_tile_layout` `comptime mma_tile_layout = Layout.row_major(KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_mmas, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)` ### `MMATileType` `comptime MMATileType = LayoutTensor[dtype, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_k_tiles` `comptime num_k_tiles = ceildiv(BK, (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_K * group_size))` ### `num_mmas` `comptime num_mmas = ceildiv(config.wsize, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMA_N)` ### `num_repeats` `comptime num_repeats = (config.btile_dim1 // KVBufferImpl[config, tensor_core_mma, swizzle, 
BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)` ### `num_warps_n` `comptime num_warps_n = (BN // WN)` ### `SharedIterType` `comptime SharedIterType = LayoutTensorIter[dtype, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]` ### `SharedTileType` `comptime SharedTileType = KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].SharedIterType.LayoutTensorType` ### `SharedWarpTileType` `comptime SharedWarpTileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace.SHARED), _get_index_type(AddressSpace.SHARED), False, align_of[dtype](), KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim0, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_index_type(AddressSpace.SHARED), linear_idx_type=_get_index_type(AddressSpace.SHARED), masked=_tile_is_masked[KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim0, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].wtile_dim1]()]` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `smem_layout` `comptime smem_layout = blocked_product(KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].base_layout, KVBufferImpl[config, 
tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].tiler_layout, True) if (not token_gen._mlir_value) else Layout.row_major(config.btile_dim0, config.btile_dim1)` ### `thread_layout` `comptime thread_layout = Layout.row_major(((min(num_threads, ((config.btile_dim0 * config.btile_dim1) // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)) * KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width) // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout.stride[0].value()), (KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].smem_layout.stride[0].value() // KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].simd_width)) if token_gen else Layout.row_major((num_threads // 4), 4)` ### `tiler_layout` `comptime tiler_layout = Layout.row_major(1, KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].num_repeats)` ### `wtile_dim0` `comptime wtile_dim0 = config.wtile_dim0` ### `wtile_dim1` `comptime wtile_dim1 = config.wtile_dim1` ## Methods ### `__init__` `__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_b_rows: OptionalReg[Int], shared_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin])` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `load_from_dram` `load_from_dram(mut self)` ### `get_mma_tile` `get_mma_tile(self) -> KVBufferImpl[config, tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen].MMATileType` **Returns:** `KVBufferImpl` ### 
`copy_to_shared` `copy_to_shared[tile_id: Int = 0](self)` ### `load_from_shared` `load_from_shared[k_mma: Int](self)`
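The method names above suggest a staged load pipeline: global memory into registers, registers into shared memory, then shared memory into MMA-shaped register tiles. A hedged sketch of that pattern (the kernel context, parameter bindings, `barrier()` placement, and loop bound are illustrative assumptions, not part of this reference):

```mojo
# Hypothetical KVBuffer pipeline sketch; bindings are assumptions.
var kv = KVBufferImpl[
    config, tensor_core_mma, swizzle, BN, WN, BK, depth,
    num_threads, num_stages, token_gen
](k_global_tile, num_b_rows, k_smem_ptr)

kv.load_from_dram()     # global memory -> register load tile
kv.copy_to_shared[0]()  # register load tile -> shared memory (stage 0)
barrier()               # assumed sync before shared-memory reads

@parameter
for k_mma in range(kv.num_k_tiles):
    kv.load_from_shared[k_mma]()  # shared memory -> MMA registers
    var tile = kv.get_mma_tile()  # register tile shaped for the MMA
    # ... issue the tensor-core MMA with `tile` ...
```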
--- ## OutputRegisterBuffer
`struct OutputRegisterBuffer[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int]` ## Fields * ​reg\_tile (`OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].RegisterTileType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `reg_dtype` `comptime reg_dtype = dtype` ### `reg_tile_layout` `comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), output_frag_size)` ### `RegisterTileType` `comptime RegisterTileType = LayoutTensor[dtype, OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ## Methods ### `__init__` `__init__(out self)` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `vectorize` `vectorize(self) -> LayoutTensor[dtype, coalesce(LayoutTensor._compute_tile_layout[True, dtype, OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), 1, output_frag_size]()[1], True), MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=LayoutTensor._divide_tiles[True, dtype, OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, 
AddressSpace.LOCAL), _get_index_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), 1, output_frag_size]()[0], layout_int_type=_get_layout_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].reg_tile_layout, AddressSpace.LOCAL)]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `apply_softmax_denominator` `apply_softmax_denominator(self, rowsum: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `zero` `zero(self)` ### `get_reg_tile` `get_reg_tile(self) -> OutputRegisterBuffer[dtype, num_m_mmas, num_n_mmas, output_frag_size].RegisterTileType` **Returns:** `OutputRegisterBuffer`
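A hedged sketch of the lifecycle these methods imply for an attention epilogue: zero the accumulator, accumulate into the register tile, apply the softmax denominator, then take a vectorized view for write-out. The parameter bindings and the `rowsum` tensor are illustrative assumptions:

```mojo
# Hypothetical OutputRegisterBuffer epilogue; bindings are assumptions.
var out = OutputRegisterBuffer[
    DType.float32, num_m_mmas, num_n_mmas, output_frag_size
]()
out.zero()  # clear the accumulator register tile

# ... accumulate the attention output fragments (P @ V) into
#     out.get_reg_tile() across the K-dimension loop ...

# Scale each row by the reciprocal softmax denominator once done.
out.apply_softmax_denominator(rowsum)

# Vectorized view of the register tile for the final global write.
var frags = out.vectorize()
```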
--- ## PRegisterBuffer
`struct PRegisterBuffer[accum_type_: DType, dtype: DType, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_m_mmas: Int, num_n_mmas: Int, output_frag_size: Int, shared_memory_backed: Bool, mma_shape: IndexList[3], k_group_size: Int, tr_load_enabled: Bool = False]` ## Fields * ​reg\_tile (`PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].RegisterTileType`): * ​shared\_memory\_tile (`PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].SharedMemoryTileType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer), [`RegisterMMABuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterMMABuffer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `mma_dtype` `comptime mma_dtype = dtype` ### `mma_tile_layout` `comptime mma_tile_layout = Layout.row_major(num_m_mmas, simd_width_of[dtype]())` ### `MMATileType` `comptime MMATileType = LayoutTensor[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].mma_dtype, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `reg_dtype` `comptime reg_dtype = accum_type_` ### `reg_tile_layout` `comptime reg_tile_layout = Layout.row_major((num_n_mmas * num_m_mmas), output_frag_size)` ### `RegisterTileType` `comptime RegisterTileType = LayoutTensor[accum_type_, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, 
num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `shared_memory_layout` `comptime shared_memory_layout = blocked_product(Layout.row_major(BM, BK), Layout.row_major(1, (BN // BK)), False)` ### `SharedMemoryTileType` `comptime SharedMemoryTileType = LayoutTensor[dtype, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].shared_memory_layout, MutAnyOrigin, address_space=AddressSpace.SHARED]` ## Methods ### `__init__` `__init__(out self, shared_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin])` ### `get_mma_tile_reg` `get_mma_tile_reg[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].MMATileType` **Returns:** `PRegisterBuffer` ### `get_mma_tile_shared` `get_mma_tile_shared[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].MMATileType` **Returns:** `PRegisterBuffer` ### `get_mma_tile` `get_mma_tile[tile_idx: Int, k_idx: Int](self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].MMATileType` **Returns:** `PRegisterBuffer` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `vectorize` `vectorize(self) -> LayoutTensor[accum_type_, coalesce(LayoutTensor._compute_tile_layout[True, accum_type_, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, 
tr_load_enabled].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, AddressSpace.LOCAL), False, align_of[accum_type_](), 1, output_frag_size]()[1], True), MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=LayoutTensor._divide_tiles[True, accum_type_, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, AddressSpace.LOCAL), False, align_of[accum_type_](), 1, output_frag_size]()[0], layout_int_type=_get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].reg_tile_layout, AddressSpace.LOCAL)]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### 
`zero` `zero(self)` ### `get_reg_tile` `get_reg_tile(self) -> PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].RegisterTileType` **Returns:** `PRegisterBuffer` ### `get_shared_memory_tile` `get_shared_memory_tile(self, tile_idx: Int) -> LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].shared_memory_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].shared_memory_layout, AddressSpace.SHARED), _get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].shared_memory_layout, AddressSpace.SHARED), False, align_of[dtype](), BM, BK]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_layout_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].shared_memory_layout, AddressSpace.SHARED), linear_idx_type=_get_index_type(PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].shared_memory_layout, AddressSpace.SHARED), masked=_tile_is_masked[PRegisterBuffer[accum_type_, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas, output_frag_size, shared_memory_backed, mma_shape, k_group_size, tr_load_enabled].shared_memory_layout, BM, BK]()]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_to_shared` 
`copy_to_shared(self)`
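`PRegisterBuffer` exposes both `get_mma_tile_reg` and `get_mma_tile_shared` alongside a unified `get_mma_tile`; a plausible reading is that the unified accessor selects the backing store according to the `shared_memory_backed` parameter, letting callers stay agnostic. A hedged sketch (the construction arguments, tile indices, and dispatch behavior are assumptions):

```mojo
# Hypothetical P-buffer usage; bindings and indices are assumptions.
var p_buf = PRegisterBuffer[
    accum_type, dtype, BM, BN, BK, WM, WN, num_m_mmas, num_n_mmas,
    output_frag_size, shared_memory_backed, mma_shape, k_group_size
](p_smem_ptr)

@parameter
if shared_memory_backed:
    # Publish register contents to shared memory for other warps.
    p_buf.copy_to_shared()

# Unified accessor abstracts over register vs. shared-memory backing.
var p_tile = p_buf.get_mma_tile[0, 0]()
```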
--- ## QRegisterBuffer
`struct QRegisterBuffer[dtype: DType, mma_shape: IndexList[3], k_group_size: Int, WM: Int, WN: Int, BN: Int, BK: Int, depth: Int, thread_layout: Layout]` ## Fields * ​reg\_tile (`QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].RegisterTileType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer), [`RegisterMMABuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterMMABuffer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `mma_dtype` `comptime mma_dtype = dtype` ### `MMA_K` `comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = mma_shape.__getitem__[3, DType.int64, Int](0)` ### `mma_tile_layout` `comptime mma_tile_layout = LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, 
AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, 
k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0].shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), 0]()[0]` ### `MMATileType` `comptime MMATileType = LayoutTensor[dtype, LayoutTensor._compute_tile_layout[True, dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, 
thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), _get_index_type(LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0], AddressSpace.LOCAL), False, align_of[dtype](), (LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout.shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), 0]()[0].shape[0].value() // QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), 0]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_k_tiles` `comptime num_k_tiles = ceildiv(BK, (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMA_K * k_group_size))` ### `num_mmas` `comptime 
num_mmas = ceildiv(WM, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMA_M)` ### `num_tiles` `comptime num_tiles = (depth // BK)` ### `reg_dtype` `comptime reg_dtype = dtype` ### `reg_tile_layout` `comptime reg_tile_layout = Layout.row_major(((QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles) * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width)` ### `RegisterTileType` `comptime RegisterTileType = LayoutTensor[dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `TiledIteratorType` `comptime TiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[True, dtype, QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), _get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), False, align_of[dtype](), (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL, axis=0, layout_int_type=_get_layout_type(QRegisterBuffer[dtype, mma_shape, 
k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), linear_idx_type=_get_index_type(QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, AddressSpace.LOCAL), masked=_tile_is_masked[QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].reg_tile_layout, (QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_mmas * QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].num_k_tiles), QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].simd_width]()]` ## Methods ### `__init__` `__init__(out self, tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `get_iter` `get_iter(self) -> QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].TiledIteratorType` **Returns:** `QRegisterBuffer` ### `get_mma_tile` `get_mma_tile[tile_idx: Int, k_idx: Int](self) -> QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].MMATileType` **Returns:** `QRegisterBuffer` ### `get_reg_tile` `get_reg_tile(self) -> QRegisterBuffer[dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout].RegisterTileType` **Returns:** `QRegisterBuffer` ### `zero` `zero(self)`
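From `num_tiles = depth // BK` and the `[tile_idx, k_idx]` accessor above, the buffer appears to hold the full query tile in registers for the kernel's lifetime, sliced per MMA step. A hedged sketch (parameter bindings and loop structure are illustrative assumptions):

```mojo
# Hypothetical Q-buffer usage; bindings are assumptions.
var q = QRegisterBuffer[
    dtype, mma_shape, k_group_size, WM, WN, BN, BK, depth, thread_layout
](q_global_tile)  # bind the query tile held for the whole kernel

@parameter
for tile_idx in range(q.num_tiles):      # depth // BK register tiles
    @parameter
    for k_idx in range(q.num_k_tiles):
        var q_frag = q.get_mma_tile[tile_idx, k_idx]()
        # ... multiply q_frag against the matching K fragment ...
```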
--- ## RegisterBuffer
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler-generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ### `reg_dtype` `comptime reg_dtype` ### `reg_tile_layout` `comptime reg_tile_layout` ## Required methods ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `zero` `zero(self: _Self)` ### `get_reg_tile` `get_reg_tile(self: _Self) -> LayoutTensor[_Self.reg_dtype, _Self.reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
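A minimal hypothetical conformer to the trait, showing how the `comptime` members and required methods fit together; the dtype and layout choices are illustrative only, not taken from any real implementation:

```mojo
# Hypothetical RegisterBuffer conformer; values are illustrative.
struct MyRegisterBuffer(RegisterBuffer):
    comptime reg_dtype = DType.float32
    comptime reg_tile_layout = Layout.row_major(4, 4)

    var reg_tile: LayoutTensor[
        Self.reg_dtype,
        Self.reg_tile_layout,
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ]

    @staticmethod
    fn get_dtype() -> DType:
        return Self.reg_dtype

    fn zero(self):
        # Clear the register-resident accumulator tile.
        _ = self.reg_tile.fill(0)

    fn get_reg_tile(self) -> LayoutTensor[
        Self.reg_dtype,
        Self.reg_tile_layout,
        MutAnyOrigin,
        address_space = AddressSpace.LOCAL,
    ]:
        return self.reg_tile
```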
--- ## RegisterMMABuffer
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`RegisterBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/RegisterBuffer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler-generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ### `mma_dtype` `comptime mma_dtype` ### `mma_tile_layout` `comptime mma_tile_layout` ### `reg_dtype` `comptime reg_dtype` ### `reg_tile_layout` `comptime reg_tile_layout` ## Required methods ### `get_mma_tile` `get_mma_tile[tile_idx: Int, k_idx: Int](self: _Self) -> LayoutTensor[_Self.mma_dtype, _Self.mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `zero` `zero(self: _Self)` ### `get_reg_tile` `get_reg_tile(self: _Self) -> LayoutTensor[_Self.reg_dtype, _Self.reg_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## VBufferConfig
`struct VBufferConfig[BN: Int, BK: Int, WN: Int, depth: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`KVBufferConfig`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBufferConfig), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `btile_dim0` `comptime btile_dim0 = BK` ### `btile_dim1` `comptime btile_dim1 = depth` ### `iterator_axis` `comptime iterator_axis = 0` ### `wsize` `comptime wsize = VBufferConfig[BN, BK, WN, depth].wtile_dim1` ### `wtile_dim0` `comptime wtile_dim0 = BK` ### `wtile_dim1` `comptime wtile_dim1 = (depth // (BN // WN))` ## Methods ### `get_wtile_coord` `static get_wtile_coord() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
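The derived tile dimensions follow directly from the formulas above. A worked example with illustrative parameter values (the concrete numbers are assumptions, chosen only to make the arithmetic visible):

```mojo
# Worked example of VBufferConfig's derived warp-tile shape for
# BN=64, BK=32, WN=32, depth=128 (illustrative values):
comptime BN = 64
comptime WN = 32
comptime BK = 32
comptime depth = 128

comptime num_warps_n = BN // WN            # 64 // 32  = 2
comptime wtile_dim0 = BK                   #             32
comptime wtile_dim1 = depth // num_warps_n # 128 // 2  = 64
comptime wsize = wtile_dim1                #             64
```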
--- ## VBufferTransposeLoads
`struct VBufferTransposeLoads[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, mut: Bool, dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, origin: Origin[mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]` ## Fields * ​load\_tile (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].LoadTileType`): * ​mma\_tile (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMATileType`): * ​smem\_iter (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].SharedIterType`): * ​global\_iterator (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].GlobalTiledIteratorType`): * ​global\_base\_tile (`VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].GlobalTensorType`): * ​current\_stage (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`KVBuffer`](/mojo/kernels/nn/attention/gpu/amd/buffers/KVBuffer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `base_layout` `comptime base_layout = Layout.row_major(VBufferTransposeLoads.pad[out_type, in_type, shape, group_size, transpose_b, mut, dtype, layout, address_space, alignment, origin, masked, layout_int_type, linear_idx_type, tensor_core_mma, BN, BK, depth, num_threads, num_stages, depth](), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)` ### `depth_tile_size` `comptime depth_tile_size = min(depth, 128)` ### `GlobalTensorType` `comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, 
linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]` ### `GlobalTiledIteratorType` `comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, BK, depth]()[0], origin, address_space=address_space, axis=0, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, BK, depth]()]` ### `load_width` `comptime load_width = 4 if (depth == 64) else VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width` ### `loads_per_thread_per_depth_tile` `comptime loads_per_thread_per_depth_tile = ((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size * BK) // (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width * num_threads))` ### `LoadTileType` `comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].loads_per_thread_per_depth_tile * (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size)) * num_stages), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `MMA_K` `comptime MMA_K = shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_M` `comptime MMA_M = shape.__getitem__[3, DType.int64, Int](0)` ### `mma_tile_layout` `comptime mma_tile_layout = Layout.row_major((depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)` ### `MMATileType` `comptime MMATileType = LayoutTensor[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].mma_tile_layout, 
MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_depth_tiles` `comptime num_depth_tiles = (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M)` ### `num_k_tiles` `comptime num_k_tiles = ceildiv(BK, (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_K * group_size))` ### `num_repeats` `comptime num_repeats = (BK // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)` ### `SharedIterType` `comptime SharedIterType = LayoutTensorIter[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]` ### `SharedTileType` `comptime SharedTileType = VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].SharedIterType.LayoutTensorType` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `smem_layout` `comptime smem_layout = blocked_product(VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].base_layout, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].tiler_layout, True)` ### `tiler_layout` `comptime tiler_layout = Layout.row_major(1, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].num_repeats)` ## Methods ### `__init__` `__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], shared_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin])` ### `get_dtype` `static get_dtype() -> DType` **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType) ### `pad` `static pad[dim: Int]() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `load_from_dram` `load_from_dram(mut self)` ### `get_mma_tile` `get_mma_tile(self) -> 
VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMATileType` **Returns:** `VBufferTransposeLoads` ### `copy_to_shared` `copy_to_shared[tile_id: Int = 0](self)` ### `load_from_shared` `load_from_shared[k_mma: Int](self)`
--- ## buffers
## `comptime` values ### `KBuffer` `comptime KBuffer[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = KVBufferImpl[KBufferConfig[BN, BK, WN], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]` #### Parameters * ​out\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​in\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​shape ([`IndexList`](/stdlib/utils/index_/IndexList)): * ​group\_size ([`Int`](/stdlib/builtin/int/Int)): * ​transpose\_b ([`Bool`](/stdlib/builtin/bool/Bool)): * ​tensor\_core\_mma ([`TiledTensorCore`](/kernels/layout/tensor_core/TiledTensorCore)): * ​swizzle ([`OptionalReg`](/stdlib/collections/optional/OptionalReg)): * ​BN ([`Int`](/stdlib/builtin/int/Int)): * ​WN ([`Int`](/stdlib/builtin/int/Int)): * ​BK ([`Int`](/stdlib/builtin/int/Int)): * ​depth ([`Int`](/stdlib/builtin/int/Int)): * ​num\_threads ([`Int`](/stdlib/builtin/int/Int)): * ​num\_stages ([`Int`](/stdlib/builtin/int/Int)): * ​token\_gen ([`Bool`](/stdlib/builtin/bool/Bool)): ### `VBuffer` `comptime VBuffer[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1, token_gen: Bool = False] = KVBufferImpl[VBufferConfig[BN, BK, WN, depth], tensor_core_mma, swizzle, BN, WN, BK, depth, num_threads, num_stages, token_gen]` #### Parameters * ​out\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​in\_type ([`DType`](/stdlib/builtin/dtype/DType)): * ​shape ([`IndexList`](/stdlib/utils/index_/IndexList)): * ​group\_size ([`Int`](/stdlib/builtin/int/Int)): * ​transpose\_b 
([`Bool`](/stdlib/builtin/bool/Bool)): * ​tensor\_core\_mma ([`TiledTensorCore`](/kernels/layout/tensor_core/TiledTensorCore)): * ​swizzle ([`OptionalReg`](/stdlib/collections/optional/OptionalReg)): * ​BN ([`Int`](/stdlib/builtin/int/Int)): * ​WN ([`Int`](/stdlib/builtin/int/Int)): * ​BK ([`Int`](/stdlib/builtin/int/Int)): * ​depth ([`Int`](/stdlib/builtin/int/Int)): * ​num\_threads ([`Int`](/stdlib/builtin/int/Int)): * ​num\_stages ([`Int`](/stdlib/builtin/int/Int)): * ​token\_gen ([`Bool`](/stdlib/builtin/bool/Bool)): ## Structs * [​`KBufferConfig`](./KBufferConfig): * [​`KVBufferImpl`](./KVBufferImpl): * [​`OutputRegisterBuffer`](./OutputRegisterBuffer): * [​`PRegisterBuffer`](./PRegisterBuffer): * [​`QRegisterBuffer`](./QRegisterBuffer): * [​`VBufferConfig`](./VBufferConfig): * [​`VBufferTransposeLoads`](./VBufferTransposeLoads): ## Traits * [​`KVBuffer`](./KVBuffer): * [​`KVBufferConfig`](./KVBufferConfig): * [​`RegisterBuffer`](./RegisterBuffer): * [​`RegisterMMABuffer`](./RegisterMMABuffer):
--- ## amd (Amd)
AMD GPU attention operations. ## Modules * [​`attention`](./attention/): * [​`buffers`](./buffers/): * [​`mha_gfx942`](./mha_gfx942/): * [​`mha_gfx950`](./mha_gfx950/): * [​`mla`](./mla/): * [​`mma`](./mma/): * [​`utils`](./utils/):
--- ## MHAAttentionConfig
`struct MHAAttentionConfig[dtype: DType, //, token_gen: Bool, config: MHAConfig[dtype], group: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`AttentionConfig`](/mojo/kernels/nn/attention/gpu/amd/attention/AttentionConfig), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `depth_padded` `comptime depth_padded = False if (not token_gen._mlir_value) if env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else True` ### `double_buffer` `comptime double_buffer = True if (not token_gen._mlir_value) if env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else False` ### `full_kv` `comptime full_kv = True if (not token_gen._mlir_value) if env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else False` ### `shared_kv` `comptime shared_kv = False if (not token_gen._mlir_value) if env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else True` ### `USE_EXPERIMENTAL_CDNA4_MHA_KERNEL` 
`comptime USE_EXPERIMENTAL_CDNA4_MHA_KERNEL = token_gen.__invert__() if env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer() else env_get_bool["USE_EXPERIMENTAL_CDNA4_MHA_KERNEL", False]() if _cdna_4_or_newer() else _cdna_4_or_newer()` ## Methods ### `q_head_idx` `static q_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `get_mma_shape` `static get_mma_shape() -> IndexList[3]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `get_q_offset` `static get_q_offset[q_depth: UInt]() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_output_offset` `static get_output_offset[output_depth: UInt]() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## mha_gfx942
## Structs * [​`MHAAttentionConfig`](./MHAAttentionConfig):
--- ## KVBuffer (Mha_gfx950)
`struct KVBuffer[kv_t: MHAOperand, //, mma_shape: IndexList[3], k_group_size: Int, swizzle: OptionalReg[Swizzle], BN: Int, WN: Int, BK: Int, num_threads: Int, depth: Int, kv_num_heads: Int, transpose: Bool]` ## Fields * ​mma\_tile (`KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMATileType`): * ​smem\_iter (`KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].SharedIterType`): * ​kv\_cache\_iter (`KVCacheIterator[kv_t, BN, kv_num_heads, depth]`): * ​buffer\_idx (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True if True if True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if True if kv_t.__del__is_trivial else kv_t.__del__is_trivial else True if kv_t.__del__is_trivial else kv_t.__del__is_trivial` ### `base_layout` `comptime base_layout = 
Layout.row_major(BN, BK)` ### `MMA_K` `comptime MMA_K = mma_shape.__getitem__[3, DType.int64, Int](2)` ### `MMA_N` `comptime MMA_N = mma_shape.__getitem__[3, DType.int64, Int](1)` ### `MMATileType` `comptime MMATileType = LayoutTensor[kv_t.dtype, Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]` ### `num_k_mmas2` `comptime num_k_mmas2 = ceildiv(BK, Int.__init__[Int]((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_K * k_group_size)))` ### `num_mmas` `comptime num_mmas = ceildiv(WN if transpose else depth, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].MMA_N)` ### `num_repeats` `comptime num_repeats = (depth // BK)` ### `SharedIterType` `comptime SharedIterType = LayoutTensorIter[kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]` ### `SharedTileType` `comptime SharedTileType = KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].SharedIterType.LayoutTensorType` ### `SharedWarpTileType` `comptime SharedWarpTileType = LayoutTensor[kv_t.dtype, LayoutTensor._compute_tile_layout[True, kv_t.dtype, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, MutAnyOrigin, AddressSpace.SHARED, Layout(IntTuple(1), IntTuple(1)), _get_index_type(AddressSpace.SHARED), _get_index_type(AddressSpace.SHARED), False, align_of[kv_t.dtype](), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, 
num_threads, depth, kv_num_heads, transpose].wtile_dim0, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim1]()[0], MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=_get_index_type(AddressSpace.SHARED), linear_idx_type=_get_index_type(AddressSpace.SHARED), masked=_tile_is_masked[KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].smem_layout, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim0, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].wtile_dim1]()]` ### `simd_width` `comptime simd_width = simd_width_of[kv_t.dtype]()` ### `smem_layout` `comptime smem_layout = blocked_product(KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].base_layout, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].tiler_layout, False)` ### `tiler_layout` `comptime tiler_layout = Layout.row_major(1, KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_repeats)` ### `wtile_dim0` `comptime wtile_dim0 = WN` ### `wtile_dim1` `comptime wtile_dim1 = BK` ## Methods ### `__init__` `__init__(out self, k_cache: kv_t, batch_idx: UInt, head_idx: UInt, shared_ptr: UnsafePointer[Scalar[kv_t.dtype], origin, address_space=AddressSpace.SHARED], end: UInt)` ### `load_from_dram` `load_from_dram(mut self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `get_mma_tile` `get_mma_tile[k_mma_tile_idx: Int](self) -> LayoutTensor[kv_t.dtype, LayoutTensor._compute_tile_layout[True, kv_t.dtype, Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, 
transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), MutAnyOrigin, AddressSpace.LOCAL, Layout(IntTuple(1), IntTuple(1)), _get_layout_type(Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), AddressSpace.LOCAL), _get_index_type(Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width), AddressSpace.LOCAL), False, align_of[kv_t.dtype](), (Layout.row_major((KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_mmas * KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].simd_width).shape[0].value() // KVBuffer[mma_shape, k_group_size, swizzle, BN, WN, BK, num_threads, depth, kv_num_heads, transpose].num_k_mmas2), 0]()[0], MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_to_shared` `copy_to_shared(self)` ### `load_from_shared` `load_from_shared(self, buffer: UInt, bk_tile: UInt)`
--- ## KVCacheIterator
`struct KVCacheIterator[cache_t: MHAOperand, tile_size: Int, kv_num_heads: Int, depth: Int]` ## Fields * ​cache (`cache_t`): * ​end (`Int`): * ​tile\_start\_row (`Int`): * ​batch\_idx (`Int`): * ​kv\_head\_idx (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True if True if True if True if cache_t.__del__is_trivial else cache_t.__del__is_trivial else True if cache_t.__del__is_trivial else cache_t.__del__is_trivial else True if True if cache_t.__del__is_trivial else cache_t.__del__is_trivial else True if cache_t.__del__is_trivial else cache_t.__del__is_trivial else True if True if True if cache_t.__del__is_trivial else cache_t.__del__is_trivial else True if cache_t.__del__is_trivial else cache_t.__del__is_trivial else True if True if cache_t.__del__is_trivial else cache_t.__del__is_trivial else True if cache_t.__del__is_trivial else cache_t.__del__is_trivial` ### `kv_gmem_layout` `comptime kv_gmem_layout = Layout(IntTuple(Int.__init__[Int](tile_size), Int.__init__[Int](depth)), IntTuple(Int.__init__[Int]((kv_num_heads * depth)), 1))` ## Methods ### `__init__` `__init__(out self, cache: cache_t, batch_idx: Int, kv_head_idx: Int, end: Int)` ### `next_unsafe` `next_unsafe(mut self) -> LayoutTensor[cache_t.dtype, KVCacheIterator[cache_t, tile_size, kv_num_heads, depth].kv_gmem_layout, MutAnyOrigin, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `increment` `increment(mut self)`
--- ## block_sync_lds
`block_sync_lds[*, lgkmcnt: UInt32 = 0]()` Synchronize LDS (local data share) with a waitcnt barrier.
--- ## block_sync_lds_direct_load
`block_sync_lds_direct_load[*, vmcnt: UInt32 = 0]()` Synchronize LDS for direct load with a waitcnt barrier.
--- ## copy_dram_to_sram_lds
`copy_dram_to_sram_lds[swizzle: OptionalReg[Swizzle] = OptionalReg[Swizzle]()](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## mha_gfx950
## Structs * [​`KVBuffer`](./KVBuffer): * [​`KVCacheIterator`](./KVCacheIterator): ## Functions * [​`block_sync_lds`](./block_sync_lds): Synchronize LDS (local data share) with a waitcnt barrier. * [​`block_sync_lds_direct_load`](./block_sync_lds_direct_load): Synchronize LDS for direct load with a waitcnt barrier. * [​`copy_dram_to_sram_lds`](./copy_dram_to_sram_lds): * [​`load_b`](./load_b): * [​`load_b_`](./load_b_):
--- ## load_b
`load_b[mma_shape: IndexList[3], swizzle: OptionalReg[Swizzle]](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, Layout.row_major((layout.size() // (WARP_SIZE * 8)), 8), MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## load_b_
`load_b_[mma_shape: IndexList[3], swizzle: OptionalReg[Swizzle], k_tile_idx: Int](src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> SIMD[dtype, simd_width_of[dtype]()]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## MLAAttentionConfig
`struct MLAAttentionConfig[dtype: DType, //, token_gen: Bool, config: MHAConfig[dtype]]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`AttentionConfig`](/mojo/kernels/nn/attention/gpu/amd/attention/AttentionConfig), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `depth_padded` `comptime depth_padded = True` ### `double_buffer` `comptime double_buffer = False` ### `full_kv` `comptime full_kv = False` ### `shared_kv` `comptime shared_kv = True` ## Methods ### `q_head_idx` `static q_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `q_tile_idx` `static q_tile_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `kv_head_idx` `static kv_head_idx() -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `get_mma_shape` `static get_mma_shape() -> IndexList[3]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `get_q_offset` `static get_q_offset[q_depth: UInt]() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_output_offset` `static get_output_offset[output_depth: UInt]() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## mla
## Structs * [​`MLAAttentionConfig`](./MLAAttentionConfig):
--- ## mma (Mma)
## Functions * [​`mma`](./mma):
--- ## mma (3)
`mma[c_register_buffer_type: RegisterBuffer, a_register_buffer_type: RegisterMMABuffer, b_buffer_type: KVBuffer, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], BK: Int, prefetch_function: OptionalReg[fn() capturing -> None], swap_a_b: Bool = False, beg_iter: Int = 0, num_iters: Int = 1, prefetched_b_tile: Bool = False](c: c_register_buffer_type, mut a_tile: a_register_buffer_type, mut b_tile: b_buffer_type)`
--- ## GlobalMemoryManager
`struct GlobalMemoryManager[dtype: DType, BM: UInt32, BN: UInt32, BK: UInt32, depth: UInt32, num_heads: UInt32, group: UInt32, token_gen: Bool, q_depth: UInt32 = depth, output_depth: UInt32 = depth]` ## Fields * ​q\_offset (`UInt32`): * ​q\_runtime\_layout (`RuntimeLayout[GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].q_gmem_layout, element_type=DType.int32, linear_idx_type=DType.int32]`): * ​output\_offset (`UInt32`): * ​output\_runtime\_layout (`RuntimeLayout[GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].output_gmem_layout, element_type=DType.int32, linear_idx_type=DType.int32]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `kv_gmem_layout` `comptime kv_gmem_layout = Layout(IntTuple(Int.__init__[UInt32](BN), Int.__init__[UInt32](depth)), IntTuple(Int.__init__[UInt32]((GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].kv_num_heads * depth)), 1))` ### `kv_num_heads` `comptime kv_num_heads = (num_heads // group)` ### `output_gmem_layout` `comptime output_gmem_layout = Layout(IntTuple(Int.__init__[UInt32](BM), Int.__init__[UInt32](output_depth)), IntTuple(Int.__init__[UInt32]((num_heads * output_depth)), 1)) if (not token_gen._mlir_value) else Layout.row_major(Int.__init__[UInt32](BM), Int.__init__[UInt32](output_depth))` ### `q_gmem_layout` `comptime q_gmem_layout = Layout(IntTuple(Int.__init__[UInt32](BM), Int.__init__[UInt32](q_depth)), IntTuple(Int.__init__[UInt32]((num_heads * q_depth)), 1)) if (not token_gen._mlir_value) else Layout.row_major(Int.__init__[UInt32](BM), Int.__init__[UInt32](q_depth))` ## Methods ### `__init__` `__init__(out self, q_tile_idx: UInt32, kv_head_idx: UInt32, seq_len: Int, q_offset: UInt32, 
output_offset: UInt32)` ### `get_q_tensor` `get_q_tensor[qtype: DType](self, ptr: LegacyUnsafePointer[Scalar[qtype]]) -> LayoutTensor[qtype, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].q_gmem_layout, MutAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `get_output_tensor` `get_output_tensor[out_type: DType](self, ptr: LegacyUnsafePointer[Scalar[out_type]]) -> LayoutTensor[out_type, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].output_gmem_layout, MutAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `get_kv_tensor` `get_kv_tensor[kvtype: DType, //](self, ptr: LegacyUnsafePointer[Scalar[kvtype], address_space=address_space, mut=mut, origin=origin], kv_tile_num_rows: UInt32) -> LayoutTensor[kvtype, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].kv_gmem_layout, origin, address_space=address_space, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## SharedMemoryManager (Utils)
`struct SharedMemoryManager[shared_kv: Bool, full_kv: Bool, depth_padded: Bool, double_buffer: Bool, dtype: DType, BM: Int, BN: Int, BK: Int, depth: Int, token_gen: Bool]` ## Fields * ​p\_smem (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): * ​k\_smem (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): * ​v\_smem (`LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[dtype]()` ### `alignment` `comptime alignment = align_of[SIMD[dtype, simd_width_of[dtype]()]]()` ### `k_smem_size` `comptime k_smem_size = ((BN * depth if full_kv else BK) * 2 if double_buffer else 1)` ### `p_smem_size` `comptime p_smem_size = (BM * BN) if token_gen else 0` ### `simd_width` `comptime simd_width = simd_width_of[dtype]()` ### `v_smem_size` `comptime v_smem_size = ((BN if full_kv else BK * pad[dtype, depth, depth]() if depth_padded else depth) * 2 if double_buffer else 1)` ## Methods ### `__init__` `__init__(out self)` ### `get_k_ptr` `get_k_ptr[_dtype: DType](self) -> LegacyUnsafePointer[Scalar[_dtype], address_space=AddressSpace.SHARED]` **Returns:** [`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer) ### `get_v_ptr` `get_v_ptr[_dtype: DType](self) -> LegacyUnsafePointer[Scalar[_dtype], address_space=AddressSpace.SHARED]` **Returns:** [`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer) ### `get_p_ptr` `get_p_ptr[_dtype: DType](self) -> LegacyUnsafePointer[Scalar[_dtype], address_space=AddressSpace.SHARED]` **Returns:** [`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer) ### `get_warp_scratch_ptr` `get_warp_scratch_ptr[_dtype: 
DType](self) -> LegacyUnsafePointer[Scalar[_dtype], address_space=AddressSpace.SHARED]` **Returns:** [`LegacyUnsafePointer`](/mojo/stdlib/memory/legacy_unsafe_pointer/LegacyUnsafePointer)
--- ## copy_local_to_dram2
`copy_local_to_dram2[dst_thread_layout: Layout, thread_scope: ThreadScope = ThreadScope.BLOCK](dst: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], src: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dst_base: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## get_fragment_layout
`get_fragment_layout[mma_shape: IndexList[3]]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## get_nested_fragment_layout
`get_nested_fragment_layout[mma_shape: IndexList[3]]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## get_warp_coords
`get_warp_coords[BN: Int, WN: Int]() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## get_warp_layout
`get_warp_layout[mma_shape: IndexList[3]]() -> Layout` **Returns:** [`Layout`](/mojo/kernels/layout/layout/Layout)
--- ## utils (Utils)
## `comptime` values ### `LocalLayoutTensor` `comptime LocalLayoutTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` #### Parameters * ​dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ### `SharedLayoutTensor` `comptime SharedLayoutTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED]` #### Parameters * ​dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/kernels/layout/layout/Layout)): ## Structs * [​`GlobalMemoryManager`](./GlobalMemoryManager): * [​`SharedMemoryManager`](./SharedMemoryManager): ## Functions * [​`copy_local_to_dram2`](./copy_local_to_dram2): * [​`get_fragment_layout`](./get_fragment_layout): * [​`get_nested_fragment_layout`](./get_nested_fragment_layout): * [​`get_warp_coords`](./get_warp_coords): * [​`get_warp_layout`](./get_warp_layout): * [​`pad`](./pad):
--- ## pad
`pad[dtype: DType, depth: Int, size: Int]() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## gpu (Gpu)
GPU attention operations. ## Packages * [​`amd`](./amd/): AMD GPU attention operations.
--- ## attention (3)
Attention operations. ## Packages * [​`gpu`](./gpu/): GPU attention operations.
--- ## cpu_bicubic_kernel
`cpu_bicubic_kernel(output_host: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_host: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Perform bicubic interpolation on a LayoutTensor of form NCHW. **Args:** * ​output\_host ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor with desired dimensions. * ​input\_host ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input tensor of shape \[B, C, H, W].
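To make the data flow concrete, here is a rough Python sketch of the same computation over a nested-list NCHW array. The helper names, the clamped border handling, and the a = -0.75 kernel are assumptions for illustration only, not details of the Mojo implementation:

```python
import math

def _cubic_weight(x: float, a: float = -0.75) -> float:
    # Cubic convolution kernel with a = -0.75 (PyTorch/OpenCV convention).
    x = abs(x)
    if x <= 1.0:
        return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
    if x < 2.0:
        return a * (((x - 5.0) * x + 8.0) * x - 4.0)
    return 0.0

def resize_bicubic(img, out_h, out_w):
    """img: [B][C][H][W] nested lists -> [B][C][out_h][out_w]."""
    b, c = len(img), len(img[0])
    in_h, in_w = len(img[0][0]), len(img[0][0][0])
    scale_h, scale_w = in_h / out_h, in_w / out_w
    clamp = lambda v, hi: max(0, min(v, hi - 1))
    out = [[[[0.0] * out_w for _ in range(out_h)]
            for _ in range(c)] for _ in range(b)]
    for bi in range(b):
        for ci in range(c):
            for oy in range(out_h):
                # Center-aligned source coordinate for this output row.
                in_y = (oy + 0.5) * scale_h - 0.5
                y0 = int(math.floor(in_y))
                for ox in range(out_w):
                    in_x = (ox + 0.5) * scale_w - 0.5
                    x0 = int(math.floor(in_x))
                    acc = 0.0
                    # Weighted average of the 4x4 neighborhood, edges clamped.
                    for dy in range(-1, 3):
                        wy = _cubic_weight(in_y - (y0 + dy))
                        row = img[bi][ci][clamp(y0 + dy, in_h)]
                        for dx in range(-1, 3):
                            wx = _cubic_weight(in_x - (x0 + dx))
                            acc += wy * wx * row[clamp(x0 + dx, in_w)]
                    out[bi][ci][oy][ox] = acc
    return out
```

Because the kernel weights at integer distances are 1, 0, 0, an identity resize (same output size as input) reproduces the input exactly.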
--- ## cubic_kernel
`cubic_kernel(x: Float32) -> Float32` Cubic interpolation kernel matching PyTorch/torchvision's BICUBIC filter. This uses the cubic convolution kernel with coefficient a = -0.75, which is what PyTorch uses in get\_cubic\_upsample\_coefficients. ([Source](https://github.com/pytorch/pytorch/blob/59eb61b2d1e4b64debbefa036acd0d8c7d55f0a3/aten/src/ATen/native/UpSample.h#L410-L423)). This also matches OpenCV's [interpolateCubic](https://github.com/opencv/opencv/blob/cf2a3c8e7430cc92569dd7f114609f9377b12d9e/modules/imgproc/src/resize.cpp#L907-L915). **Args:** * ​x ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): Distance from the center point. **Returns:** [`Float32`](/mojo/stdlib/builtin/simd/#float32): Weight contribution based on the distance. `cubic_kernel(x: SIMD[dtype, size]) -> SIMD[dtype, size]` Cubic interpolation kernel matching PyTorch/torchvision's BICUBIC filter. This uses the cubic convolution kernel with coefficient a = -0.75, which is what PyTorch uses in get\_cubic\_upsample\_coefficients. ([Source](https://github.com/pytorch/pytorch/blob/59eb61b2d1e4b64debbefa036acd0d8c7d55f0a3/aten/src/ATen/native/UpSample.h#L410-L423)). This also matches OpenCV's [interpolateCubic](https://github.com/opencv/opencv/blob/cf2a3c8e7430cc92569dd7f114609f9377b12d9e/modules/imgproc/src/resize.cpp#L907-L915). **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Distance from the center point. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Weight contribution based on the distance.
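The weight function described above can be sketched in Python (a reimplementation for illustration only, not the Mojo source):

```python
def cubic_kernel(x: float, a: float = -0.75) -> float:
    """Cubic convolution weight for a sample at distance |x| from the
    interpolation point, with a = -0.75 (the PyTorch/OpenCV convention)."""
    x = abs(x)
    if x <= 1.0:
        # Near segment: ((a + 2)|x| - (a + 3)) * |x|^2 + 1
        return ((a + 2.0) * x - (a + 3.0)) * x * x + 1.0
    if x < 2.0:
        # Far segment: a * (|x|^3 - 5|x|^2 + 8|x| - 4)
        return a * (((x - 5.0) * x + 8.0) * x - 4.0)
    return 0.0
```

For any fractional offset t in [0, 1), the four tap weights `cubic_kernel(1 + t)`, `cubic_kernel(t)`, `cubic_kernel(1 - t)`, and `cubic_kernel(2 - t)` sum to 1, so constant images are preserved by the interpolation.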
--- ## gpu_bicubic_kernel
`gpu_bicubic_kernel[dtype: DType, input_layout: Layout, output_layout: Layout, address_space: AddressSpace = AddressSpace.GENERIC](output: LayoutTensor[dtype, output_layout, MutAnyOrigin, address_space=address_space], input: LayoutTensor[dtype, input_layout, MutAnyOrigin, address_space=address_space])` Perform bicubic interpolation using GPU. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor with desired dimensions on the device. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input tensor of shape \[B, C, H, W] on the device.
--- ## bicubic
This module provides CPU and GPU implementations for bicubic interpolation. Bicubic interpolation is a 2D extension of cubic interpolation for resampling digital images. It uses the weighted average of the 4x4 neighborhood of pixels around the target location to compute the interpolated value. ## Functions * [​`cpu_bicubic_kernel`](./cpu_bicubic_kernel): Perform bicubic interpolation on a LayoutTensor in NCHW format. * [​`cubic_kernel`](./cubic_kernel): Cubic interpolation kernel matching PyTorch/torchvision's BICUBIC filter. * [​`gpu_bicubic_kernel`](./gpu_bicubic_kernel): Perform bicubic interpolation using GPU. * [​`map_output_to_input_coord`](./map_output_to_input_coord): Map output pixel coordinate to input coordinate using center alignment. * [​`resize_bicubic`](./resize_bicubic): Perform bicubic interpolation.
--- ## map_output_to_input_coord
`map_output_to_input_coord(output_coord: Int, scale: Float32) -> Float32` Map output pixel coordinate to input coordinate using center alignment. This implements the standard coordinate mapping for image resizing: input\_coord = (output\_coord + 0.5) \* scale - 0.5 The +0.5 and -0.5 terms ensure pixel centers are aligned properly. **Args:** * ​output\_coord ([`Int`](/mojo/stdlib/builtin/int/Int)): Output pixel coordinate. * ​scale ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): Scale factor (input\_size / output\_size). **Returns:** [`Float32`](/mojo/stdlib/builtin/simd/#float32): Corresponding input coordinate as a float.
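The mapping is one line of arithmetic; a Python sketch for illustration:

```python
def map_output_to_input_coord(output_coord: int, scale: float) -> float:
    # +0.5 moves to the pixel center before scaling, -0.5 moves back after,
    # so pixel centers (not corners) are aligned between the two grids.
    return (output_coord + 0.5) * scale - 0.5
```

With scale = 1 the mapping is the identity, and with scale = 2 (downsampling by 2x) output pixel 0 lands at input coordinate 0.5, midway between input pixels 0 and 1.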
--- ## resize_bicubic
`resize_bicubic[dtype: DType, //, target: StringSlice[StaticConstantOrigin]](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Perform bicubic interpolation. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor with desired dimensions on host or device. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input tensor of shape \[B, C, H, W] on host or device. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): Device context to enqueue GPU kernels on.
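Putting the coordinate mapping and the cubic kernel together, here is a minimal pure-Python sketch of a single-channel bicubic resize. The helper names (`_weight`, `resize_bicubic`) and the border-clamping policy are illustrative assumptions, not the Mojo API:

```python
import math

def _weight(x, a=-0.75):
    """Cubic convolution weight with a = -0.75 (see cubic_kernel above)."""
    x = abs(x)
    if x <= 1.0:
        return (a + 2) * x**3 - (a + 3) * x**2 + 1
    if x < 2.0:
        return a * (x**3 - 5 * x**2 + 8 * x - 4)
    return 0.0

def resize_bicubic(img, out_h, out_w):
    in_h, in_w = len(img), len(img[0])
    scale_h, scale_w = in_h / out_h, in_w / out_w
    out = [[0.0] * out_w for _ in range(out_h)]
    for oy in range(out_h):
        in_y = (oy + 0.5) * scale_h - 0.5      # center-aligned mapping
        y0 = math.floor(in_y)
        for ox in range(out_w):
            in_x = (ox + 0.5) * scale_w - 0.5
            x0 = math.floor(in_x)
            acc = 0.0
            for dy in range(-1, 3):             # 4x4 neighborhood
                wy = _weight(in_y - (y0 + dy))
                yy = min(max(y0 + dy, 0), in_h - 1)  # clamp to border
                for dx in range(-1, 3):
                    wx = _weight(in_x - (x0 + dx))
                    xx = min(max(x0 + dx, 0), in_w - 1)
                    acc += wy * wx * img[yy][xx]
            out[oy][ox] = acc
    return out
```

Because the weights sum to 1, a constant image stays constant at any output size, and a same-size resize reproduces the input.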
--- ## broadcast
`broadcast[dtype: DType](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` For each axis of `input`, if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer.
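The duplicate-or-copy rule above can be sketched for the 2-D case in Python (a hypothetical helper, not the Mojo kernel): axes of size 1 in the input are duplicated to fill the corresponding output axis, and all other axes are copied through.

```python
def broadcast_2d(input_rows, out_shape):
    in_h, in_w = len(input_rows), len(input_rows[0])
    out_h, out_w = out_shape
    # Each input axis must either match the output axis or have size 1.
    assert in_h in (1, out_h) and in_w in (1, out_w)
    return [
        [input_rows[r if in_h == out_h else 0][c if in_w == out_w else 0]
         for c in range(out_w)]
        for r in range(out_h)
    ]
```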
--- ## broadcast_impl
`broadcast_impl[dtype: DType](axis: Int, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_prev_axis_stride: Int, output_prev_axis_stride: Int, input_offset: Int, output_offset: Int, rightmost_broadcast_axis: Int)` For each axis of `input` ∈ \[axis, rank), if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`. **Args:** * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis value. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer. * ​input\_prev\_axis\_stride ([`Int`](/mojo/stdlib/builtin/int/Int)): The stride at axis `axis - 1` for input. * ​output\_prev\_axis\_stride ([`Int`](/mojo/stdlib/builtin/int/Int)): The stride at axis `axis - 1` for output. * ​input\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset at which we start copying data from. * ​output\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset at which we start copying data to. * ​rightmost\_broadcast\_axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The largest axis at which we need to duplicate `input` data.
--- ## broadcast (Broadcast)
## Functions * [​`broadcast`](./broadcast): For each axis of `input`, if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`. * [​`broadcast_impl`](./broadcast_impl): For each axis of `input` ∈ \[axis, rank), if the dimension is 1, duplicate the data at each index of the corresponding axis in `output`, otherwise copy over the entire axis to the corresponding axis in `output`.
--- ## concat (Concat)
`concat[output_layout: Layout, inputs_layout: Layout, //, dtype: DType, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu", epilogue_fn: OptionalReg[fn[c_type: DType, rank: Int, width: Int = 1, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None] = None](output: LayoutTensor[dtype, output_layout, origin], axis: Int, inputs: StaticTuple[LayoutTensor[dtype, inputs_layout, MutAnyOrigin], size], context: DeviceContextPtr = DeviceContextPtr())`
--- ## concat_shape
`concat_shape[inputs_layout: Layout, //, input_type: DType, single_thread_blocking_override: Bool](input_bufs: List[LayoutTensor[input_type, inputs_layout, MutAnyOrigin]], axis: Int) -> IndexList[inputs_layout.rank()]` Compute the output shape of a `concat` operation, and assert the inputs are compatible. **Parameters:** * ​inputs\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Input layout of the input tensor. * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_bufs ([`List`](/mojo/stdlib/collections/list/List)): The input tensors list. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
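The concat shape rule itself is simple, sketched here in Python as an illustration: every input must match the first input on all axes except `axis`, and the output sums the sizes along `axis`.

```python
def concat_shape(shapes, axis):
    base = list(shapes[0])
    total = 0
    for s in shapes:
        for d in range(len(base)):
            if d != axis:
                # Incompatible inputs are rejected, mirroring the assert
                # described in the docs above.
                assert s[d] == base[d], "non-concat axes must match"
        total += s[axis]
    base[axis] = total
    return tuple(base)
```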
--- ## fused_concat
`fused_concat[output_layout: Layout, //, dtype: DType, rank: Int, single_thread_blocking_override: Bool, input_fn: fn[input_index: Int, width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[dtype, width], output_0_fn: elementwise_epilogue_type, target: StringSlice[StaticConstantOrigin] = "cpu"](axis: Int, input_shapes: StaticTuple[IndexList[rank], size], output: LayoutTensor[dtype, output_layout, origin], ctx: DeviceContextPtr)`
--- ## concat (3)
## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[c_type: DType, rank: Int, width: Int = 1, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None` ## Functions * [​`concat`](./concat): * [​`concat_shape`](./concat_shape): Compute the output shape of a `concat` operation, and assert the inputs are compatible. * [​`fused_concat`](./fused_concat): * [​`memcpy_or_fuse`](./memcpy_or_fuse):
--- ## memcpy_or_fuse
`memcpy_or_fuse[rank: Int, dtype: DType, epilogue_fn: OptionalReg[fn[c_type: DType, rank: Int, width: Int = 1, *, alignment: Int = 1](IndexList[rank], SIMD[c_type, width]) capturing -> None]](dest_data: LegacyUnsafePointer[Int8], out_byte_offset: Int, src_data: LegacyUnsafePointer[Int8], n: Int, out_shape: IndexList[rank, element_type=element_type])`
--- ## ConvDirectNHWC
`struct ConvDirectNHWC[input_mut: Bool, filter_mut: Bool, conv_attr_rank: Int, //, input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_origin: Origin[input_mut], filter_origin: Origin[filter_mut], output_origin: MutOrigin, input_type: DType, filter_type: DType, output_type: DType, filter_packed: Bool, conv_attr: ConvInfoStatic[conv_attr_rank], elementwise_epilogue: OptionalReg[fn[rank: Int](coords: IndexList[rank], f_size: Int) capturing -> None] = None]` Implement the outer loops for direct convolution. Collapse N, HO, WO into one dimension n\_ho\_wo. Tile n\_ho\_wo, C, and F. The tile factors for C and F are chosen by a heuristic prioritizing C. n\_ho\_wo is tiled by the micro kernel's height. If n\_ho\_wo is large enough to spill the LLC, we may need to tile n\_ho\_wo as the outermost loop with a factor that fits in the LLC. Assume F is divisible at least by simd\_size. ## Fields * ​output (`LayoutTensor[output_type, output_layout, output_origin]`): * ​input (`LayoutTensor[input_type, input_layout, input_origin]`): * ​filter (`LayoutTensor[filter_type, filter_layout, filter_origin]`): * ​conv\_shape (`ConvShape[conv_attr_rank]`): * ​partition (`ConvPartition`): * ​cf\_tile\_size (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `packed_and_fully_static` `comptime packed_and_fully_static = filter_packed if filter_layout.shape.all_known() if output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, 
input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else filter_layout.shape.all_known() if output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else output_layout.shape.all_known[1, output_layout.rank()]() if input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]() else input_layout.shape.all_known[1, input_layout.rank()]() if conv_attr.all_known[conv_attr_rank]() else conv_attr.all_known[conv_attr_rank]()` ## Methods ### `run` `static run(output: LayoutTensor[output_type, output_layout, output_origin], input: LayoutTensor[input_type, input_layout, input_origin], filter: LayoutTensor[filter_type, filter_layout, filter_origin], conv_shape: ConvShape[conv_attr_rank])` ### `is_new_c_accum` `is_new_c_accum(self, c_idx: Int) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `update_output_tile_no_padding` `update_output_tile_no_padding[micro_kernel_height: Int, micro_kernel_width: Int, c_fully_cached: Bool, has_residual: Bool, last_c_tile: Bool](self, n: Int, f_tile_offset: Int, f_tile_size: Int, c_tile_offset: Int, c_tile_size: Int, 
output_flat_coord: Int)` ### `output_space_flat_loop` `output_space_flat_loop[micro_kernel_f_size: Int, has_residual: Bool, last_c_tile: Bool](self, n: Int, f_tile_offset: Int, f_tile_size: Int, c_tile_offset: Int, c_tile_size: Int)` ### `output_space_loop` `output_space_loop[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool](self, n: Int, f_tile_offset: Int, f_tile_size: Int, c_tile_offset: Int, c_tile_size: Int)` ### `output_space_loop_1d` `output_space_loop_1d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)` ### `output_space_loop_2d` `output_space_loop_2d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)` ### `output_space_loop_3d` `output_space_loop_3d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)`
--- ## CuDNNConvMeta
`@register_passable` `struct CuDNNConvMeta` ## Fields * ​ptr\_handle (`LegacyUnsafePointer[cudnnContext]`): * ​ptr\_input\_desc (`LegacyUnsafePointer[cudnnTensorStruct]`): * ​ptr\_filter\_desc (`LegacyUnsafePointer[cudnnFilterStruct]`): * ​ptr\_conv\_desc (`LegacyUnsafePointer[cudnnConvolutionStruct]`): * ​ptr\_output\_desc (`LegacyUnsafePointer[cudnnTensorStruct]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` ### `__del__` `__del__(deinit self)`
--- ## Naive2dConvolution
`struct Naive2dConvolution[output_type: DType, input_type: DType, filter_type: DType]` Struct wrapper for naive 2d convolution implementation. ## Fields * ​output (`LegacyUnsafePointer[Scalar[output_type]]`): * ​input (`LegacyUnsafePointer[Scalar[input_type]]`): * ​filter (`LegacyUnsafePointer[Scalar[filter_type]]`): * ​pad\_d (`IndexList[2]`): * ​pad\_h (`IndexList[2]`): * ​pad\_w (`IndexList[2]`): * ​stride (`IndexList[3]`): * ​dilation (`IndexList[3]`): * ​num\_groups (`Int`): * ​output\_shape (`IndexList[5]`): * ​input\_shape (`IndexList[5]`): * ​filter\_shape (`IndexList[5]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, output: LegacyUnsafePointer[Scalar[output_type]], input: LegacyUnsafePointer[Scalar[input_type]], filter: LegacyUnsafePointer[Scalar[filter_type]], output_shape: IndexList[5], input_shape: IndexList[5], filter_shape: IndexList[5], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2], stride: IndexList[3], dilation: IndexList[3], num_groups: Int)` ### `run` `static run(output: LegacyUnsafePointer[Scalar[output_type]], input: LegacyUnsafePointer[Scalar[input_type]], filter: LegacyUnsafePointer[Scalar[filter_type]], output_shape: IndexList[5], input_shape: IndexList[5], filter_shape: IndexList[5], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2], stride: IndexList[3], dilation: IndexList[3], num_groups: Int)`
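As a reference for what the naive implementation computes, here is a simplified single-channel, 2-D analogue in Python (the actual struct also handles depth, channels, groups, and asymmetric padding; `naive_conv2d` is an illustrative name):

```python
def naive_conv2d(inp, filt, stride, dilation, pad):
    H, W = len(inp), len(inp[0])
    R, S = len(filt), len(filt[0])
    # Standard output-size formula with symmetric zero padding.
    HO = (H + 2 * pad - dilation * (R - 1) - 1) // stride + 1
    WO = (W + 2 * pad - dilation * (S - 1) - 1) // stride + 1
    out = [[0.0] * WO for _ in range(HO)]
    for ho in range(HO):
        for wo in range(WO):
            acc = 0.0
            for r in range(R):
                for s in range(S):
                    h = ho * stride - pad + r * dilation
                    w = wo * stride - pad + s * dilation
                    if 0 <= h < H and 0 <= w < W:  # out-of-range taps are zero
                        acc += inp[h][w] * filt[r][s]
            out[ho][wo] = acc
    return out
```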
--- ## accumulate_wo_tile_1d
`accumulate_wo_tile_1d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, partial_load_filter: Bool, effected_by_padding: Bool, input_dt: DType, filter_dt: DType](c_tile_size: Int, S: Int, mut acc: _Accumulator[dtype, num_rows, num_cols, simd_width, row_start, row_stop], input: LegacyUnsafePointer[Scalar[input_dt]], input_stride: Int, input_stride_to_nbr: Int, filter: LegacyUnsafePointer[Scalar[filter_dt]], filter_stride: Int, filter_stride_to_nbr: Int, partial_load_filter_size: Int, w: Int, W: Int, dilation: Int)` Update one row in the output for a given (c, f) tile. **Parameters:** * ​micro\_kernel\_height ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of input points in register tiling. * ​micro\_kernel\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of SIMD registers assigned to F. * ​simd\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements in a SIMD register. * ​partial\_load\_filter ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to use partial loads for the filter. * ​effected\_by\_padding ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the tile is affected by padding. * ​input\_dt ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType of input. * ​filter\_dt ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType of filter. **Args:** * ​c\_tile\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Tile size in input channel. * ​S ([`Int`](/mojo/stdlib/builtin/int/Int)): Filter window width. * ​acc ([`_Accumulator`](/mojo/kernels/linalg/accumulate/_Accumulator)): Pointer to register tile accumulator. * ​input (`LegacyUnsafePointer`): Pointer to the first input point in the WO tile. * ​input\_stride ([`Int`](/mojo/stdlib/builtin/int/Int)): Stride between two input points, i.e., C w/ NHWC layout. * ​input\_stride\_to\_nbr ([`Int`](/mojo/stdlib/builtin/int/Int)): Stride between an input point and its neighbor. * ​filter (`LegacyUnsafePointer`): Pointer to the first coef in the filter window. 
* ​filter\_stride ([`Int`](/mojo/stdlib/builtin/int/Int)): Stride between two segments of size `micro_kernel_width * simd_size`. * ​filter\_stride\_to\_nbr ([`Int`](/mojo/stdlib/builtin/int/Int)): Stride between two neighboring coefs, i.e., CF w/ RSCF layout. * ​partial\_load\_filter\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of the partial load for the filter. * ​w ([`Int`](/mojo/stdlib/builtin/int/Int)): Coordinate in an input row. * ​W ([`Int`](/mojo/stdlib/builtin/int/Int)): Input width. * ​dilation ([`Int`](/mojo/stdlib/builtin/int/Int)): Convolution dilation.
--- ## accumulate_wo_tile_2d
`accumulate_wo_tile_2d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, partial_load_filter: Bool, effected_by_padding: Bool, input_dt: DType, filter_dt: DType](c_tile_size: Int, RS: IndexList[2], mut acc: _Accumulator[dtype, num_rows, num_cols, simd_width, row_start, row_stop], input: LegacyUnsafePointer[Scalar[input_dt]], input_stride: Int, input_stride_to_nbr: IndexList[2], filter: LegacyUnsafePointer[Scalar[filter_dt]], filter_stride: Int, filter_stride_to_nbr: IndexList[2], partial_load_filter_size: Int, hw: IndexList[2], HW: IndexList[2], dilation: IndexList[2])`
--- ## accumulate_wo_tile_3d
`accumulate_wo_tile_3d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, partial_load_filter: Bool, effected_by_padding: Bool, input_dt: DType, filter_dt: DType](c_tile_size: Int, QRS: IndexList[3], mut acc: _Accumulator[dtype, num_rows, num_cols, simd_width, row_start, row_stop], input: LegacyUnsafePointer[Scalar[input_dt]], input_stride: Int, input_stride_to_nbr: IndexList[3], filter: LegacyUnsafePointer[Scalar[filter_dt]], filter_stride: Int, filter_stride_to_nbr: IndexList[3], partial_load_filter_size: Int, dhw: IndexList[3], DHW: IndexList[3], dilation: IndexList[3])`
--- ## check_cudnn_error
`check_cudnn_error(stat: cudnnStatus_t)`
--- ## conv1d_update_wo_tile
`conv1d_update_wo_tile[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, filter_packed: Bool, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType, elementwise_epilogue: OptionalReg[fn[rank: Int](coords: IndexList[rank], f_size: Int) capturing -> None] = None](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], first_c_tile: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[rank], n: Int, wo: Int)`
--- ## conv2d_gpu_naive_nhwc_rscf
`conv2d_gpu_naive_nhwc_rscf[input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, block_size: Int, maybe_epilogue_func: OptionalReg[fn[dtype: DType, rank: Int, width: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None]](input: LayoutTensor[input_type, input_layout, MutAnyOrigin], filter: LayoutTensor[filter_type, filter_layout, MutAnyOrigin], output: LayoutTensor[output_type, output_layout, MutAnyOrigin], stride: IndexList[2], dilation: IndexList[2], padding: IndexList[2])`
--- ## conv2d_update_wo_tile
`conv2d_update_wo_tile[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, filter_packed: Bool, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType, elementwise_epilogue: OptionalReg[fn[rank: Int](coords: IndexList[rank], f_size: Int) capturing -> None] = None](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], first_c_tile: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[2], n: Int, howo: IndexList[2])`
--- ## conv3d_gpu_naive_ndhwc_qrscf
`conv3d_gpu_naive_ndhwc_qrscf[input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, block_size: Int, maybe_epilogue_func: OptionalReg[fn[dtype: DType, rank: Int, width: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None]](input: LayoutTensor[input_type, input_layout, MutAnyOrigin], filter: LayoutTensor[filter_type, filter_layout, MutAnyOrigin], output: LayoutTensor[output_type, output_layout, MutAnyOrigin], stride: IndexList[3], dilation: IndexList[3], padding: IndexList[3])`
--- ## conv3d_update_wo_tile
`conv3d_update_wo_tile[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, filter_packed: Bool, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType, elementwise_epilogue: OptionalReg[fn[rank: Int](coords: IndexList[rank], f_size: Int) capturing -> None] = None](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], first_c_tile: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[3], n: Int, dohowo: IndexList[3])`
--- ## conv_cudnn
`conv_cudnn[input_type: DType, filter_type: DType, output_type: DType](input: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[filter_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stride: IndexList[2], dilation: IndexList[2], padding: IndexList[2], num_groups: Int, ctx: DeviceContext)`
--- ## conv_gpu
`conv_gpu[conv_rank: Int, //, input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, maybe_epilogue_func: OptionalReg[fn[dtype: DType, rank: Int, width: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None] = None, filter_is_fcrs: Bool = False](input: LayoutTensor[input_type, input_layout, MutAnyOrigin], filter: LayoutTensor[filter_type, filter_layout, MutAnyOrigin], output: LayoutTensor[output_type, output_layout, MutAnyOrigin], stride: IndexList[conv_rank], dilation: IndexList[conv_rank], padding: IndexList[(2 * conv_rank)], num_groups: Int, ctx: DeviceContext)`
--- ## conv_nhwc_direct
`conv_nhwc_direct[conv_info_rank: Int, //, input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, filter_packed: Bool, conv_info_static: ConvInfoStatic[conv_info_rank], lambdas_have_fusion: Bool, elementwise_lambda: elementwise_simd_epilogue_type](input: LayoutTensor[input_type, input_layout, origin], filter: LayoutTensor[filter_type, filter_layout, origin], output: LayoutTensor[output_type, output_layout, origin], stride: IndexList[conv_info_rank], dilation: IndexList[conv_info_rank], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2], num_groups: Int)`
--- ## conv_shape
`conv_shape[input_type: DType, filter_type: DType, strides_type: DType, dilations_type: DType, paddings_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter_buf: LayoutTensor[filter_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides_buf: LayoutTensor[strides_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations_buf: LayoutTensor[dilations_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings_buf: LayoutTensor[paddings_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups_scalar: Scalar[dtype]) -> IndexList[LayoutTensor[input_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `conv` operation, and assert the inputs are compatible. **Parameters:** * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​filter\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the filter tensor. * ​strides\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the strides tensor. * ​dilations\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the dilations tensor. * ​paddings\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the paddings tensor. 
* ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​filter\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The filter tensor. * ​strides\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The strides tensor. * ​dilations\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The dilations tensor. * ​paddings\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The paddings tensor. * ​num\_groups\_scalar ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The num\_groups scalar. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
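The per-dimension output-size rule that a conv shape computation applies can be sketched in Python (parameter names `pad_lo`/`pad_hi` here stand for the two entries of a dimension's padding pair and are illustrative):

```python
def conv_out_dim(in_size, kernel, stride, dilation, pad_lo, pad_hi):
    # A dilated kernel covers dilation*(kernel-1)+1 input elements.
    effective_kernel = dilation * (kernel - 1) + 1
    return (in_size + pad_lo + pad_hi - effective_kernel) // stride + 1
```

For example, a 3x3 kernel with stride 1 and padding 1 preserves spatial size, while a 7x7 kernel with stride 2 and padding 3 halves it.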
--- ## get_cudnn_dtype
`get_cudnn_dtype[dtype: DType]() -> cudnnDataType_t` Map Mojo DType to cuDNN data type. Supports only floating-point dtypes for now. **Returns:** `cudnnDataType_t`
--- ## conv
## Structs * [​`ConvDirectNHWC`](./ConvDirectNHWC): Implement the outer loops for direct convolution. Collapse N, HO, WO into one dimension n\_ho\_wo. Tile n\_ho\_wo, C, and F. The tile factors for C and F are chosen by a heuristic prioritizing C. n\_ho\_wo is tiled by the micro kernel's height. * [​`CuDNNConvMeta`](./CuDNNConvMeta): * [​`Naive2dConvolution`](./Naive2dConvolution): Struct wrapper for naive 2d convolution implementation. ## Functions * [​`accumulate_wo_tile_1d`](./accumulate_wo_tile_1d): Update one row in the output for a given (c, f) tile. * [​`accumulate_wo_tile_2d`](./accumulate_wo_tile_2d): * [​`accumulate_wo_tile_3d`](./accumulate_wo_tile_3d): * [​`check_cudnn_error`](./check_cudnn_error): * [​`conv1d_update_wo_tile`](./conv1d_update_wo_tile): * [​`conv2d_gpu_naive_nhwc_rscf`](./conv2d_gpu_naive_nhwc_rscf): * [​`conv2d_update_wo_tile`](./conv2d_update_wo_tile): * [​`conv3d_gpu_naive_ndhwc_qrscf`](./conv3d_gpu_naive_ndhwc_qrscf): * [​`conv3d_update_wo_tile`](./conv3d_update_wo_tile): * [​`conv_cudnn`](./conv_cudnn): * [​`conv_gpu`](./conv_gpu): * [​`conv_nhwc_direct`](./conv_nhwc_direct): * [​`conv_shape`](./conv_shape): Compute the output shape of a `conv` operation, and assert the inputs are compatible. * [​`get_cudnn_dtype`](./get_cudnn_dtype): Map Mojo DType to cuDNN data type. * [​`pack_conv_filter_shape`](./pack_conv_filter_shape): Compute the output shape of convolution filter packing. * [​`pack_filter`](./pack_filter): This packs the filter from RSCF to FRSCf. Use the default micro kernel size for dynamic shapes. * [​`pack_filter_shape`](./pack_filter_shape): Compute the shape of packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel. * [​`pack_filter_shape_impl`](./pack_filter_shape_impl): Compute the shape of packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel.
--- ## pack_conv_filter_shape
`pack_conv_filter_shape[single_thread_blocking_override: Bool](filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int) -> IndexList[(LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank + 1)]` Compute the output shape of convolution filter packing. **Parameters:** * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The filter to be packed. * ​num\_groups ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of groups in the convolution. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## pack_filter
`pack_filter(filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], packed_filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int)` This packs the filter from RSCF to FRSCf. Use the default micro kernel size for dynamic shapes. `pack_filter[simd_size: Int, micro_kernel_f_size: Int](filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], packed_filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int)` This packs the filter from RSCF to FRSCf. F is first broken down into segments of size micro\_kernel\_f\_size, then the remainder is further divided by simd\_size. Any residual elements are padded with zeros to fill simd\_size. **Parameters:** * ​simd\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Can differ from the simd size of the input type. * ​micro\_kernel\_f\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the last dimension in FRSCf, which equals the size of the micro kernel's F dimension. **Args:** * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Filter in RSCF layout (if 2D). * ​packed\_filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Packed filter in FRSCf layout (if 2D). F - the index of the contiguous segment in the micro kernel. R, S, C - original R, S, C. f - the index within a contiguous segment. 
* ​num\_groups ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of groups in the convolution.
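The F-dimension segmentation described above can be sketched in Python (a hedged illustration of the packing scheme only, not the Mojo implementation; `packed_f_segments` is a hypothetical helper name):

```python
def packed_f_segments(F, micro_kernel_f_size, simd_size):
    """Illustrative sketch: widths of the packed segments along F.

    F is first broken into segments of size micro_kernel_f_size, the
    remainder is divided into simd_size chunks, and any residual
    elements are zero-padded up to a full simd_size chunk.
    """
    segments = []
    remaining = F
    # Full micro-kernel-sized segments come first.
    while remaining >= micro_kernel_f_size:
        segments.append(micro_kernel_f_size)
        remaining -= micro_kernel_f_size
    # The remainder is split into simd-sized chunks.
    while remaining >= simd_size:
        segments.append(simd_size)
        remaining -= simd_size
    # Residual elements are padded with zeros to fill simd_size.
    if remaining > 0:
        segments.append(simd_size)
    return segments
```

For example, F=70 with micro\_kernel\_f\_size=32 and simd\_size=8 packs as segments [32, 32, 8], with the last chunk zero-padded from 6 to 8 elements.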
--- ## pack_filter_shape
`pack_filter_shape[filter_type: DType, input_shape: DimList, filter_shape: DimList, output_shape: DimList, strides: DimList, dilations: DimList, paddings: DimList, num_groups: Int, single_thread_blocking_override: Bool](filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[(LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank + 1)]` Compute the shape of the packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## pack_filter_shape_impl
`pack_filter_shape_impl[filter_type: DType](Q: Int, R: Int, S: Int, C: Int, F: Int, num_groups: Int) -> IndexList[6]` Compute the shape of the packed filter. The packed layout is FRSCf. shape\_ref should be allocated with size 5 outside this kernel. **Args:** * ​Q ([`Int`](/mojo/stdlib/builtin/int/Int)): Original Q filter dimension. * ​R ([`Int`](/mojo/stdlib/builtin/int/Int)): Original R filter dimension. * ​S ([`Int`](/mojo/stdlib/builtin/int/Int)): Original S filter dimension. * ​C ([`Int`](/mojo/stdlib/builtin/int/Int)): Original C filter dimension. * ​F ([`Int`](/mojo/stdlib/builtin/int/Int)): Original F filter dimension. * ​num\_groups ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of groups in the convolution. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## ConvTransposedPacked
`struct ConvTransposedPacked[input_mut: Bool, input_element_layout: Layout, input_layout_int_type: DType, input_linear_idx_type: DType, input_masked: Bool, input_alignment: Int, filter_mut: Bool, filter_element_layout: Layout, filter_layout_int_type: DType, filter_linear_idx_type: DType, filter_masked: Bool, filter_alignment: Int, output_element_layout: Layout, output_layout_int_type: DType, output_linear_idx_type: DType, output_masked: Bool, output_alignment: Int, //, input_origin: Origin[input_mut], filter_origin: Origin[filter_mut], output_origin: MutOrigin, input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, conv_attr: ConvInfoStatic[(input_layout.rank() - 2)], elementwise_epilogue: OptionalReg[fn[rank: Int](coords: IndexList[rank], f_size: Int) capturing -> None] = None]` ## Fields * ​output (`LayoutTensor[output_type, output_layout, output_origin, element_layout=output_element_layout, layout_int_type=output_layout_int_type, linear_idx_type=output_linear_idx_type, masked=output_masked, alignment=output_alignment]`): * ​input (`LayoutTensor[input_type, input_layout, input_origin, element_layout=input_element_layout, layout_int_type=input_layout_int_type, linear_idx_type=input_linear_idx_type, masked=input_masked, alignment=input_alignment]`): * ​filter (`LayoutTensor[filter_type, filter_layout, filter_origin, element_layout=filter_element_layout, layout_int_type=filter_layout_int_type, linear_idx_type=filter_linear_idx_type, masked=filter_masked, alignment=filter_alignment]`): * ​conv\_shape (`ConvShape[(input_layout.rank() - 2)]`): * ​partition (`ConvPartition`): * ​cf\_tile\_size (`IndexList[2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), 
[`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `run` `static run(output: LayoutTensor[output_type, output_layout, output_origin, element_layout=output_element_layout, layout_int_type=output_layout_int_type, linear_idx_type=output_linear_idx_type, masked=output_masked, alignment=output_alignment], input: LayoutTensor[input_type, input_layout, input_origin, element_layout=input_element_layout, layout_int_type=input_layout_int_type, linear_idx_type=input_linear_idx_type, masked=input_masked, alignment=input_alignment], filter: LayoutTensor[filter_type, filter_layout, filter_origin, element_layout=filter_element_layout, layout_int_type=filter_layout_int_type, linear_idx_type=filter_linear_idx_type, masked=filter_masked, alignment=filter_alignment], conv_shape: ConvShape[(input_layout.rank() - 2)])` ### `input_space_loop` `input_space_loop[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool](self, n: Int, f_tile_offset: Int, f_tile_size: Int, c_tile_offset: Int, c_tile_size: Int)` ### `input_space_loop_2d` `input_space_loop_2d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)` ### `input_space_loop_3d` `input_space_loop_3d[micro_kernel_height: Int, micro_kernel_width: Int, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](self, output: 
LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], n: Int, first_c_tile_in_group: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, left_pad_impact_end: Int, right_pad_impact_start: Int)` ### `apply_epilogue` `apply_epilogue(self, n: Int, g: Int)`
--- ## accumulate_wo_tile
`accumulate_wo_tile[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, partial_load: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](c_tile_size: Int, output: LegacyUnsafePointer[Scalar[output_dt]], output_stride: Int, input: LegacyUnsafePointer[Scalar[input_dt]], input_stride: Int, filter: LegacyUnsafePointer[Scalar[filter_dt]], filter_stride: Int, partial_load_size: Int)`
--- ## conv_transpose_naive
`conv_transpose_naive[dtype: DType](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stride: IndexList[3], dilation: IndexList[3], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2])` Implements the ConvTranspose operator from the MO spec. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input, output, and kernel tensors. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output data tensor that contains the result of the convolution. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input data tensor from previous layer, with size of (N x H x W x C), where N is the batch size, C is the number of channels, and H and W are the height and width. * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The weight (kernel) tensor, with size of (kH x kW x M/groups x C), where C is the number of channels, kH and kW are the height and width of the kernel, and M is the number of feature maps. * ​stride ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Stride along each spatial axis. * ​dilation ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Dilation value along each spatial axis of the filter. * ​pad\_d ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Padding in depth dimension. 
* ​pad\_h ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Padding in height dimension. * ​pad\_w ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Padding in width dimension.
--- ## conv_transpose_shape
`conv_transpose_shape[dtype: DType, strides_type: DType, dilations_type: DType, pads_type: DType, output_pads_type: DType, single_thread_blocking_override: Bool](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kernel: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides: LayoutTensor[strides_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations: LayoutTensor[dilations_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], pads: LayoutTensor[pads_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_pads: LayoutTensor[output_pads_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `conv-transpose` operation, and assert the inputs are compatible. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element type of the input and kernel tensor. * ​strides\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element type of the strides tensor. 
* ​dilations\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element type of the dilations tensor. * ​pads\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element type of the pads tensor. * ​output\_pads\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Element type of the output\_pads tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​kernel ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The kernel tensor. * ​strides ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The strides tensor. * ​dilations ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The dilations tensor. * ​pads ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The paddings tensor. * ​output\_pads ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output paddings tensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
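For one spatial axis, a transposed convolution's output size follows the standard relation used by ConvTranspose operators (a hedged Python sketch assuming the common ONNX-style convention; the exact convention is not spelled out above):

```python
def conv_transpose_out_dim(in_dim, kernel, stride, dilation,
                           pad_begin, pad_end, output_pad=0):
    # Effective kernel extent once dilation is applied.
    effective_kernel = dilation * (kernel - 1) + 1
    # Standard ConvTranspose output-size relation per spatial axis.
    return (stride * (in_dim - 1) + output_pad
            + effective_kernel - pad_begin - pad_end)
```

For example, in\_dim=4, kernel=3, stride=2, dilation=1 with symmetric padding of 1 gives an output size of 7.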
--- ## conv_transposed_cpu
`conv_transposed_cpu[input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, filter_packed: Bool, filter_is_cfrs: Bool, lambdas_have_fusion: Bool, elementwise_lambda: fn[dtype: DType, rank: Int, width: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None](output: LayoutTensor[output_type, output_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[input_type, input_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[filter_type, filter_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stride: IndexList[(LayoutTensor[input_type, input_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)], dilation: IndexList[(LayoutTensor[input_type, input_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2])`
--- ## conv_transposed_cudnn
`conv_transposed_cudnn[input_type: DType, filter_type: DType, output_type: DType](input: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[filter_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stride: IndexList[2], dilation: IndexList[2], padding: IndexList[2], ctx: DeviceContext)`
--- ## conv_transposed_gpu
`conv_transposed_gpu[input_layout: Layout, filter_layout: Layout, output_layout: Layout, input_type: DType, filter_type: DType, output_type: DType, elementwise_epilogue: OptionalReg[fn[dtype: DType, rank: Int, width: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None] = None](output: LayoutTensor[output_type, output_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[input_type, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[filter_type, filter_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stride: IndexList[(LayoutTensor[input_type, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)], dilation: IndexList[(LayoutTensor[input_type, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)], padding: IndexList[(LayoutTensor[input_type, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank - 2)], ctx: DeviceContext)`
--- ## get_num_partitions
`get_num_partitions[micro_kernel_height: Int, micro_kernel_f_size: Int](num_threads: Int, conv_shape: ConvShape[rank]) -> IndexList[4]` Partition the workload in (batch\&group, C, F, H) dimensions. HOWO is the combination of HO and WO dimensions. The actual number of tasks is the product of the returned num\_partitions. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
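Since the per-dimension partition counts multiply into the task count, a minimal sketch (`total_tasks` is a hypothetical helper name):

```python
from math import prod

def total_tasks(num_partitions):
    # num_partitions holds partition counts along (batch&group, C, F, H);
    # the actual number of tasks is their product.
    return prod(num_partitions)
```

For example, partitions (2, 1, 4, 3) yield 24 tasks.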
--- ## get_partition
`get_partition(task_id: Int, num_partitions: IndexList[4], conv_shape: ConvShape[rank], micro_kernel_height: Int, micro_kernel_f_size: Int) -> ConvPartition` **Returns:** `ConvPartition`
--- ## conv_transpose
## Structs * [​`ConvTransposedPacked`](./ConvTransposedPacked): ## Functions * [​`accumulate_wo_tile`](./accumulate_wo_tile): * [​`conv_transpose_naive`](./conv_transpose_naive): Implements the ConvTranspose operator from the MO spec. * [​`conv_transpose_shape`](./conv_transpose_shape): Compute the output shape of a `conv-transpose` operation, and assert the inputs are compatible. * [​`conv_transposed_cpu`](./conv_transposed_cpu): * [​`conv_transposed_cudnn`](./conv_transposed_cudnn): * [​`conv_transposed_gpu`](./conv_transposed_gpu): * [​`get_num_partitions`](./get_num_partitions): Partition the workload in (batch\&group, C, F, H) dimensions. HOWO is the combination of HO and WO dimensions. The actual number of tasks is the product of the returned num\_partitions. * [​`get_partition`](./get_partition): * [​`pack_filter`](./pack_filter): This packs the filter from RSFC to FRSCf. * [​`pack_filter_shape`](./pack_filter_shape): Compute the output shape of transposed convolution filter packing. * [​`update_w_tile_2d`](./update_w_tile_2d): * [​`update_w_tile_3d`](./update_w_tile_3d):
--- ## pack_filter (Conv_transpose)
`pack_filter(filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], packed_filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int)` This packs the filter from RSFC to FRSCf.
--- ## pack_filter_shape (Conv_transpose)
`pack_filter_shape(filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int) -> IndexList[(LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank + 1)]` Compute the output shape of transposed convolution filter packing. **Args:** * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The filter to be packed. * ​num\_groups ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of groups in the convolution. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## update_w_tile_2d
`update_w_tile_2d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], _init_output: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[2], n: Int, hw: IndexList[2])`
--- ## update_w_tile_3d
`update_w_tile_3d[micro_kernel_height: Int, micro_kernel_width: Int, simd_size: Int, effected_by_padding: Bool, has_residual: Bool, last_c_tile: Bool, output_dt: DType, input_dt: DType, filter_dt: DType](output: LegacyUnsafePointer[Scalar[output_dt]], input: LegacyUnsafePointer[Scalar[input_dt]], filter: LegacyUnsafePointer[Scalar[filter_dt]], _init_output: Bool, c_tile_size: Int, f_tile_offset: Int, f_tile_size: Int, conv_shape: ConvShape[3], n: Int, hw: IndexList[3])`
--- ## ConvAlgorithm
`@register_passable(trivial)` `struct ConvAlgorithm` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Default` `comptime Default = ConvAlgorithm(0)` ### `Direct` `comptime Direct = ConvAlgorithm(2)` ### `Im2Col` `comptime Im2Col = ConvAlgorithm(1)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## ConvInfoStatic
`struct ConvInfoStatic[rank: Int]` ## Fields * ​pad (`IntTuple`): * ​stride (`IntTuple`): * ​dilation (`IntTuple`): * ​num\_groups (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ## Methods ### `__init__` `__init__(out self, pad: IntTuple, stride: IntTuple, dilation: IntTuple, num_groups: Int)` `__init__(out self)` `__init__(out self, pad: IntTuple, stride: IntTuple, dilation: IntTuple, input_c: Int, filter_c: Int)` ### `all_known` `all_known(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `pad_left` `pad_left(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `pad_bottom` `pad_bottom(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `strides` `strides(self) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `dilations` `dilations(self) -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## ConvPartition
`@register_passable(trivial)` `struct ConvPartition` Work range for a partition. ## Fields * ​ng\_offset (`Int`): * ​ng\_size (`Int`): * ​f\_offset (`Int`): * ​f\_size (`Int`): * ​ho\_or\_howo\_offset (`Int`): * ​ho\_or\_howo\_size (`Int`): * ​c\_offset (`Int`): * ​c\_size (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `empty` `empty(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## ConvShape
`@register_passable(trivial)` `struct ConvShape[rank: Int]` A shape struct describing the convolution dimensions. ## Fields * ​n (`Int`): * ​input\_dims (`IndexList[rank]`): * ​output\_dims (`IndexList[rank]`): * ​filter\_dims (`IndexList[rank]`): * ​c (`Int`): * ​f (`Int`): * ​stride (`IndexList[rank]`): * ​dilation (`IndexList[rank]`): * ​pad\_d (`IndexList[2]`): * ​pad\_h (`IndexList[2]`): * ​pad\_w (`IndexList[2]`): * ​num\_groups (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `d` `d(self) -> Int` Input depth. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `h` `h(self) -> Int` Input height. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `w` `w(self) -> Int` Input width. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `do` `do(self) -> Int` Output depth. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `ho` `ho(self) -> Int` Output height. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `wo` `wo(self) -> Int` Output width. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `q` `q(self) -> Int` Filter window depth. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `r` `r(self) -> Int` Filter window height. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `s` `s(self) -> Int` Filter window width. 
**Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `filter_window_flat_size` `filter_window_flat_size(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `input_image_flat_size` `input_image_flat_size(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `output_image_flat_size` `output_image_flat_size(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `output_space_dims` `output_space_dims(self) -> IndexList[rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList) ### `output_flat_coord_to_input_offset` `output_flat_coord_to_input_offset(self, n: Int, output_flat_coord: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `matmul_M` `matmul_M(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `matmul_N` `matmul_N(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `matmul_K` `matmul_K(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `padded` `padded(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `c_per_group` `c_per_group(self) -> Int` Returns the number of channels per group. Channel count must be divisible by group size. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `f_per_group` `f_per_group(self) -> Int` Returns the number of filters per group. Filter count must be divisible by group size. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `f_to_group` `f_to_group(self, f_idx: Int) -> Int` Given a global filter idx, returns the group idx of the group the filter belongs to. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `c_to_group` `c_to_group(self, c_idx: Int) -> Int` Given a global channel idx, returns the group idx of the group the channel belongs to. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `f_in_group` `f_in_group(self, f_idx: Int) -> Int` Given a global filter idx, returns the offset of the filter in its group. 
**Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `c_in_group` `c_in_group(self, c_idx: Int) -> Int` Given a global channel idx, returns the offset of the channel in its group. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
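The grouped-convolution index helpers above reduce to simple integer arithmetic; a hedged Python sketch of their semantics (free functions standing in for the `ConvShape` methods):

```python
def c_per_group(c, num_groups):
    # Channels per group; channel count must be divisible by group count.
    assert c % num_groups == 0
    return c // num_groups

def f_per_group(f, num_groups):
    # Filters per group; filter count must be divisible by group count.
    assert f % num_groups == 0
    return f // num_groups

def f_to_group(f_idx, f, num_groups):
    # Group index that a global filter index belongs to.
    return f_idx // f_per_group(f, num_groups)

def f_in_group(f_idx, f, num_groups):
    # Offset of a global filter index within its group.
    return f_idx % f_per_group(f, num_groups)
```

With f=8 and num\_groups=2, filter index 5 lands in group 1 at offset 1.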
--- ## align_down_residual
`align_down_residual(value: Int, alignment: Int) -> Int` Returns the remainder after aligning down value to alignment. **Args:** * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): The value to align. * ​alignment ([`Int`](/mojo/stdlib/builtin/int/Int)): Value to align to. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The remainder after aligning down value to the closest multiple of alignment. In other words, value - align\_down(value, alignment).
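The definition is direct to restate in Python (a sketch of the stated formula, value - align\_down(value, alignment)):

```python
def align_down(value, alignment):
    # Round value down to the closest multiple of alignment.
    return (value // alignment) * alignment

def align_down_residual(value, alignment):
    # Remainder after aligning value down to a multiple of alignment.
    return value - align_down(value, alignment)
```

For example, align\_down\_residual(13, 4) is 1.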
--- ## append_shape
`append_shape[rank: Int](in_shape: IndexList[rank], last2nd: Int, last: Int) -> IndexList[(rank + 2)]` Append input shape by inserting `last2nd` and `last` at the end. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## extend_shape
`extend_shape[rank: Int](in_shape: IndexList[rank], first: Int, last: Int) -> IndexList[(rank + 2)]` Extend input shape by inserting `first` and `last` at both ends. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
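Both shape helpers grow the rank by two; their semantics can be sketched in Python with tuples standing in for `IndexList` (an illustrative sketch, not the Mojo implementation):

```python
def extend_shape(in_shape, first, last):
    # Insert `first` and `last` at both ends of the shape.
    return (first, *in_shape, last)

def append_shape(in_shape, last2nd, last):
    # Append `last2nd` and `last` at the end of the shape.
    return (*in_shape, last2nd, last)
```

For example, extend\_shape((4, 5, 6), 1, 3) gives (1, 4, 5, 6, 3).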
--- ## get_conv2d_shape
`get_conv2d_shape[output_layout: Layout, input_layout: Layout, filter_layout_param: Layout, dtype: DType, data_layout: Image2DLayout, filter_layout: Image2DLayout](output: LayoutTensor[dtype, output_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[dtype, filter_layout_param, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], pad_h: IndexList[2], pad_w: IndexList[2], stride: IndexList[2], dilation: IndexList[2], num_groups: Int) -> ConvShape[2]` **Returns:** [`ConvShape`](/mojo/kernels/nn/conv_utils/ConvShape)
--- ## get_conv_num_partitions
`get_conv_num_partitions[micro_kernel_w: Int, micro_kernel_f: Int](num_threads: Int, conv_shape: ConvShape[rank]) -> IndexList[4]` Partition the workload in (batch, C, F, HOWO) dimensions. HOWO is the combination of the HO and WO dimensions. The actual number of tasks is the product of the returned num\_partitions. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## get_conv_num_tasks
`get_conv_num_tasks(num_threads: Int, conv_shape: ConvShape[rank]) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_conv_shape
`get_conv_shape[rank: Int, filter_packed: Bool](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stride: IndexList[rank], dilation: IndexList[rank], pad_d: IndexList[2], pad_h: IndexList[2], pad_w: IndexList[2], num_groups: Int) -> ConvShape[rank]` **Returns:** [`ConvShape`](/mojo/kernels/nn/conv_utils/ConvShape)
--- ## get_conv_tile_shape
`get_conv_tile_shape[dtype: DType](c: Int, filter_window_size: Int, micro_kernel_width: Int) -> IndexList[2]` Compute the (c, f) tile shape in L2. Assuming NHWC layout, the tile shape is (R, S, c\_tile, f\_tile). R and S are fully covered by default. The heuristic tries to block in C as much as possible; if C is small, it starts to block F. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## get_conv_tile_size
`get_conv_tile_size[dtype: DType]() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_direct_conv_micro_kernel_height
`get_direct_conv_micro_kernel_height() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_direct_conv_micro_kernel_width
`get_direct_conv_micro_kernel_width() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## get_micro_kernel_shape
`get_micro_kernel_shape[rank: Int, WO: Int, F: Int, conv_attr: ConvInfoStatic[rank], simd_size: Int]() -> IndexList[2]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## get_partition (Conv_utils)
`get_partition(task_id: Int, num_partitions: IndexList[4], conv_shape: ConvShape[rank], micro_kernel_height: Int, micro_kernel_f_size: Int) -> ConvPartition` **Returns:** `ConvPartition`
--- ## conv_utils
## `comptime` values ### `elementwise_epilogue_type` `comptime elementwise_epilogue_type = fn[rank: Int](coords: IndexList[rank], f_size: Int) capturing -> None` ### `elementwise_simd_epilogue_type` `comptime elementwise_simd_epilogue_type = fn[dtype: DType, rank: Int, width: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None` ## Structs * [​`ConvAlgorithm`](./ConvAlgorithm): * [​`ConvInfoStatic`](./ConvInfoStatic): * [​`ConvPartition`](./ConvPartition): Work range for a partition. * [​`ConvShape`](./ConvShape): A shape struct describing the convolution dimensions. ## Functions * [​`align_down_residual`](./align_down_residual): Returns the remainder after aligning down value to alignment. * [​`append_shape`](./append_shape): Append input shape by inserting `last2nd` and `last` at the end. * [​`extend_shape`](./extend_shape): Extend input shape by inserting `first` and `last` at both ends. * [​`get_conv2d_shape`](./get_conv2d_shape): * [​`get_conv_num_partitions`](./get_conv_num_partitions): Partition the workload in (batch, C, F, HOWO) dimensions. HOWO is the combination of the HO and WO dimensions. The actual number of tasks is the product of the returned num\_partitions. * [​`get_conv_num_tasks`](./get_conv_num_tasks): * [​`get_conv_shape`](./get_conv_shape): * [​`get_conv_tile_shape`](./get_conv_tile_shape): Compute the (c, f) tile shape in L2. Assuming NHWC layout, the tile shape is (R, S, c\_tile, f\_tile). R and S are fully covered by default. The heuristic tries to block in C as much as possible; if C is small, it starts to block F. * [​`get_conv_tile_size`](./get_conv_tile_size): * [​`get_direct_conv_micro_kernel_height`](./get_direct_conv_micro_kernel_height): * [​`get_direct_conv_micro_kernel_width`](./get_direct_conv_micro_kernel_width): * [​`get_micro_kernel_shape`](./get_micro_kernel_shape): * [​`get_partition`](./get_partition): * [​`reorder_padding`](./reorder_padding):
--- ## reorder_padding
`reorder_padding[rank: Int](pad: IntTuple) -> IntTuple` **Returns:** [`IntTuple`](/mojo/kernels/layout/int_tuple/IntTuple)
--- ## cumsum
`cumsum[dtype: DType, exclusive: Bool, reverse: Bool](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int)` Implements the CumSum operator from the ONNX spec: Computes cumulative sum of the input elements along the given axis. Cumulative sum can be inclusive or exclusive of the top element, and normal or reverse (direction along a given axis). **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input and output tensors. * ​exclusive ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If set to True, return exclusive sum (top element not included). * ​reverse ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If set to True, perform cumsum operation in reverse direction. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis on which to perform the cumsum operation.
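The four combinations of `exclusive` and `reverse` can be illustrated with a minimal Python reference over a 1-D list (a sketch of the documented ONNX semantics; axis handling and `LayoutTensor` details are omitted):

```python
def cumsum(xs, exclusive=False, reverse=False):
    # Reverse direction is a forward cumsum over the reversed input.
    if reverse:
        return cumsum(xs[::-1], exclusive)[::-1]
    out, total = [], 0
    for x in xs:
        if exclusive:
            out.append(total)  # top element excluded
            total += x
        else:
            total += x
            out.append(total)  # top element included
    return out

print(cumsum([1, 2, 3]))                               # [1, 3, 6]
print(cumsum([1, 2, 3], exclusive=True))               # [0, 1, 3]
print(cumsum([1, 2, 3], reverse=True))                 # [6, 5, 3]
print(cumsum([1, 2, 3], exclusive=True, reverse=True)) # [5, 3, 0]
```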
--- ## cumsum (Cumsum)
## Functions * [​`cumsum`](./cumsum): Implements the CumSum operator from the ONNX spec: Computes cumulative sum of the input elements along the given axis. Cumulative sum can be inclusive or exclusive of the top element, and normal or reverse (direction along a given axis).
--- ## flash_attention
`flash_attention[dtype: DType, rank: Int, mask_rank: Int, //, input_k_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_v_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_mask_fn: fn[simd_width: Int, mask_rank: Int](IndexList[mask_rank]) capturing -> SIMD[dtype, simd_width]](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_shape: IndexList[rank], v_shape: IndexList[rank], mask_shape: IndexList[mask_rank], output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## flash_attention_kv_cache
`flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, //](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: cache_t, v: cache_t, mask: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)` `flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, mask_t: MHAMask, //](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: cache_t, v: cache_t, mask: mask_t, scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)` `flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, mask_t: MHAMask, //](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: cache_t, v: cache_t, mask: mask_t, scale: 
Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)` Entrypoint for ragged tensors.
--- ## flash_attention_split_kv
`flash_attention_split_kv[dtype: DType, rank: Int, mask_rank: Int, //, input_k_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_v_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_k_cache_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_v_cache_fn: fn[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width], input_mask_fn: fn[simd_width: Int, mask_rank: Int](IndexList[mask_rank]) capturing -> SIMD[dtype, simd_width]](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_shape: IndexList[rank], v_shape: IndexList[rank], k_cache_shape: IndexList[(rank + 1)], v_cache_shape: IndexList[(rank + 1)], mask_shape: IndexList[mask_rank], output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32)` Variant of flash attention that takes the previous KV cache `input_{k,v}_cache_fn` and the current KV tensors `input_k_fn` and `input_v_fn` as separate arguments. This works around the fact that fusion can't currently look through concat. So this kernel does an in-place concat fusion by changing the input lambdas `input_{k,v}_cache_fn_wrapper` to take previous sequence KV elements from the KV cache, and current KV elements from tensors `k` and `v`.
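The routing performed by the wrapper lambdas can be sketched in Python: positions before the cached sequence length read from the KV cache, later positions read from the current K/V tensors, emulating the concat without materializing it (a deliberate simplification over 1-D lists; the function name is hypothetical):

```python
def split_kv_read(k_cache, k_new, s):
    # Past positions come from the cache; new positions come from the
    # current K tensor, as if the two had been concatenated along seq.
    prev_len = len(k_cache)
    return k_cache[s] if s < prev_len else k_new[s - prev_len]

cache, new = [10, 11], [20]
print([split_kv_read(cache, new, s) for s in range(3)])  # [10, 11, 20]
```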
--- ## flash_attention (Flash_attention)
## Functions * [​`flash_attention`](./flash_attention): * [​`flash_attention_kv_cache`](./flash_attention_kv_cache): * [​`flash_attention_split_kv`](./flash_attention_split_kv): Variant of flash attention that takes the previous KV cache `input_{k,v}_cache_fn` and the current KV tensors `input_k_fn` and `input_v_fn` as separate arguments.
--- ## fold
`fold[dtype: DType, stride: Tuple[Int, Int], dilation: Tuple[Int, Int], padding: Tuple[Int, Int], target: StringSlice[StaticConstantOrigin]](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_size: IndexList[2], kernel_size: IndexList[2], ctx: DeviceContextPtr)` Folds array of sliding local blocks into a single output tensor. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for the input and output. * ​stride ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Stride of the sliding blocks. * ​dilation ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Dilation of the sliding blocks. * ​padding ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): 0-paddings to be added on both sides of the inputs. * ​target (`StringSlice`): The target architecture to compile for. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input tensor to fold, shape \[N, C x kernel size, num\_blocks]. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor to write to, shape \[N, C, H, W]. * ​output\_size ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Spatial shape of the output tensor (H, W). * ​kernel\_size ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Size of the sliding blocks. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): DeviceContextPtr.
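A naive single-sample Python reference of the fold (col2im) computation may help: each column of the block matrix is scattered back to its spatial window, and overlapping positions are summed. This sketch supports stride only (no padding or dilation) for brevity:

```python
def fold(blocks, out_h, out_w, kh, kw, stride=1):
    # blocks: [C * kh * kw][num_blocks] column matrix for one sample.
    c = len(blocks) // (kh * kw)
    out = [[[0.0] * out_w for _ in range(out_h)] for _ in range(c)]
    n_h = (out_h - kh) // stride + 1  # sliding-window positions per axis
    n_w = (out_w - kw) // stride + 1
    for ch in range(c):
        for i in range(kh):
            for j in range(kw):
                row = blocks[ch * kh * kw + i * kw + j]
                for b in range(n_h * n_w):
                    y = (b // n_w) * stride + i
                    x = (b % n_w) * stride + j
                    out[ch][y][x] += row[b]  # overlaps are summed
    return out

# One 2x2 block exactly tiles a 2x2 output.
print(fold([[1.0], [2.0], [3.0], [4.0]], 2, 2, 2, 2))
```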
--- ## fold_shape
`fold_shape[dtype: DType](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_size: IndexList[2], kernel_size: IndexList[2]) -> IndexList[4]` Returns the shape of the output tensor of the fold operation. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## fold (Fold)
Implements the fold operation. ## Functions * [​`fold`](./fold): Folds array of sliding local blocks into a single output tensor. * [​`fold_shape`](./fold_shape): Returns the shape of the output tensor of the fold operation.
--- ## fused_qk_rope
`fused_qk_rope[dtype: DType, collection_t: KVCollectionT, //, cache_t: KVCacheT, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin]](q_proj: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, freqs_cis: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], layer_idx: UInt32, valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: Optional[DeviceContext])` Applies RoPE to query and key tensors. **Args:** * ​q\_proj ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Query projection tensor of shape \[batch, seq\_len, n\_heads, head\_dim]. * ​kv\_collection (`collection_t`): The KV cache collection containing the key cache. * ​freqs\_cis ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Frequency tensor for RoPE of shape \[max\_seq\_len, head\_dim]. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The layer index for accessing the correct cache. * ​valid\_lengths ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of shape \[batch] containing the valid length for each sequence. RoPE is only applied to positions within these lengths. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor for Q with RoPE applied, same shape as q\_proj. * ​context ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Optional device context for GPU execution.
--- ## fused_qk_rope_ragged
`fused_qk_rope_ragged[dtype: DType, freq_dtype: DType, collection_t: KVCollectionT, //, cache_t: KVCacheT, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin], mrope_section: Optional[IntTuple] = None](q_proj: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, freqs_cis: LayoutTensor[freq_dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], position_ids: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major[2](), MutAnyOrigin]], layer_idx: UInt32, output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: Optional[DeviceContext])` Applies RoPE (Rotary Position Embedding) to query and key tensors. This function can apply RoPE to only the last `rope_dim` elements of each head, leaving the first `unroped_dim` elements unchanged. This is required for DeepSeek models, where only part of each head undergoes the rotary transformation.
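For intuition, the interleaved rotary transform applied to one head vector can be sketched in Python (a simplified reference, not the fused kernel; `base = 10000` is the conventional default and an assumption here):

```python
import math

def apply_rope(x, pos, base=10000.0):
    # Rotate consecutive (even, odd) pairs of one head vector by a
    # position-dependent angle -- the interleaved RoPE convention.
    d = len(x)
    out = []
    for i in range(0, d, 2):
        theta = pos * base ** (-i / d)
        c, s = math.cos(theta), math.sin(theta)
        out += [x[i] * c - x[i + 1] * s, x[i] * s + x[i + 1] * c]
    return out

# Position 0 is the identity rotation.
print(apply_rope([1.0, 2.0, 3.0, 4.0], 0))  # [1.0, 2.0, 3.0, 4.0]
```

Because each pair undergoes a pure rotation, the vector's norm is preserved for any position.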
--- ## get_identity_rope_coeff
`get_identity_rope_coeff[width: Int, dtype: DType]() -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## get_safetensors_idx
`get_safetensors_idx(head_dim_idx: Int, head_size: Int) -> Tuple[Int, Int]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)
--- ## fused_qk_rope (Fused_qk_rope)
## Functions * [​`fused_qk_rope`](./fused_qk_rope): Applies RoPE to query and key tensors. * [​`fused_qk_rope_ragged`](./fused_qk_rope_ragged): Applies RoPE (Rotary Position Embedding) to query and key tensors. * [​`get_identity_rope_coeff`](./get_identity_rope_coeff): * [​`get_safetensors_idx`](./get_safetensors_idx): * [​`rope_k_cache`](./rope_k_cache): * [​`rope_q_proj`](./rope_q_proj):
--- ## rope_k_cache
`rope_k_cache[freq_dtype: DType, cache_t: KVCacheT, width: Int, //, *, interleaved: Bool](k_cache: cache_t, b_idx: Int, h_idx: Int, s_idx: Int, d_idx: Int, freq_val: SIMD[freq_dtype, width], head_size: Int)`
--- ## rope_q_proj
`rope_q_proj[dtype: DType, freq_dtype: DType, rank: Int, width: Int, //, *, interleaved: Bool](q_proj: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], idx: IndexList[rank], freq_val: SIMD[freq_dtype, width], head_size: Int)`
--- ## Axis
`@register_passable(trivial)` `struct Axis` ## Fields * ​axis (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Indexer`](/mojo/stdlib/builtin/int/Indexer), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(axis: Int) -> Self` `__init__(out self, axis: Int, rank: Int)` ### `__int__` `__int__(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `__mlir_index__` `__mlir_index__(self) -> __mlir_type.index` Convert to index. **Returns:** `__mlir_type.index`: The corresponding \_\_mlir\_type.index value.
--- ## gather
`gather[dtype: DType, indices_type: DType, //, *, axis: Int, target: StringSlice[StaticConstantOrigin] = "cpu"](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], *, context: DeviceContext)` Gather operation as defined in the ONNX spec. Note that this is NOT the same as the default PyTorch gather (which is equivalent to `gather_elements`). `gather[dtype: DType, indices_type: DType, //, *, axis: Int, target: StringSlice[StaticConstantOrigin] = "cpu"](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], *, context: DeviceContextPtr = DeviceContextPtr())` Gather operation as defined in the ONNX spec. Note that this is NOT the same as the default PyTorch gather (which is equivalent to `gather_elements`).
`gather[*, dtype: DType, indices_type: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], indices_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[indices_type, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, prefetch_fn: OptionalReg[fn[input_rank: Int, indices_rank: Int](IndexList[input_rank], IndexList[indices_rank]) capturing -> None] = None, target: StringSlice[StaticConstantOrigin] = "cpu", single_thread_blocking_override: Bool = False](axis: Axis, input_shape: IndexList[size, element_type=element_type], indices_shape: IndexList[size, element_type=element_type], output_shape: IndexList[size, element_type=element_type], *, context: DeviceContext)` Gather operation as defined in the ONNX spec. Note that this is NOT the same as the default PyTorch gather (which is equivalent to `gather_elements`). `gather[*, dtype: DType, indices_type: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], indices_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[indices_type, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, prefetch_fn: OptionalReg[fn[input_rank: Int, indices_rank: Int](IndexList[input_rank], IndexList[indices_rank]) capturing -> None] = None, target: StringSlice[StaticConstantOrigin] = "cpu", single_thread_blocking_override: Bool = False](axis: Axis, input_shape: IndexList[size, element_type=element_type], indices_shape: IndexList[size, element_type=element_type], output_shape: IndexList[size, element_type=element_type], *, context: DeviceContextPtr = DeviceContextPtr())` Gather operation as defined in the ONNX spec. Note that this is NOT the same as the default PyTorch gather (which is equivalent to `gather_elements`).
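For the common `axis = 0` case, ONNX-style gather selects whole slices of the input, unlike PyTorch's elementwise `gather`. A minimal Python sketch with nested lists:

```python
def gather_axis0(data, indices):
    # Each (possibly negative) index selects a full slice of `data`
    # along axis 0; the output stacks those slices in indices order.
    return [data[i] for i in indices]

data = [[1, 2], [3, 4], [5, 6]]
print(gather_axis0(data, [2, 0]))  # [[5, 6], [1, 2]]
print(gather_axis0(data, [-1]))    # [[5, 6]]
```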
--- ## gather_elements
`gather_elements[input_type: DType, indices_type: DType](input: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], _axis: Int, output: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Implements the ONNX GatherElements op, which is equivalent to PyTorch gather.
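A 2-D Python reference of the GatherElements semantics: the output has the shape of `indices`, and each output element picks exactly one input element along the chosen axis (a sketch, not the Mojo kernel):

```python
def gather_elements_2d(data, indices, axis):
    # out[r][c] = data[indices[r][c]][c]  when axis == 0
    # out[r][c] = data[r][indices[r][c]]  when axis == 1
    rows, cols = len(indices), len(indices[0])
    if axis == 0:
        return [[data[indices[r][c]][c] for c in range(cols)]
                for r in range(rows)]
    return [[data[r][indices[r][c]] for c in range(cols)]
            for r in range(rows)]

data = [[1, 2], [3, 4]]
print(gather_elements_2d(data, [[0, 1], [1, 0]], axis=1))  # [[1, 2], [4, 3]]
print(gather_elements_2d(data, [[0, 1], [1, 0]], axis=0))  # [[1, 4], [3, 2]]
```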
--- ## gather_elementwise_fn_wrapper
`gather_elementwise_fn_wrapper[*, dtype: DType, indices_type: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], indices_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[indices_type, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, simd_width: Int, prefetch_fn: OptionalReg[fn[input_rank: Int, indices_rank: Int](IndexList[input_rank], IndexList[indices_rank]) capturing -> None] = None, error_index_fn: OptionalReg[fn(Int) capturing -> None] = None](axis: Axis, input_shape: IndexList[size, element_type=element_type], indices_shape: IndexList[size, element_type=element_type], output_shape: IndexList[size, element_type=element_type], coords: IndexList[size, element_type=element_type])`
--- ## gather_guards
`gather_guards(axis: Axis, input_shape: IndexList[size, element_type=element_type], indices_shape: IndexList[size, element_type=element_type], output_shape: IndexList[size, element_type=element_type])`
--- ## gather_nd
`gather_nd[dtype: DType, indices_type: DType, batch_dims: Int, target: StringSlice[StaticConstantOrigin] = "cpu", single_thread_blocking_override: Bool = False](data: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` GatherND operation as defined in the ONNX spec, based on its reference implementation. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of data tensor. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of indices tensor. * ​batch\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of batch dimensions. Gathering indexes into data\[batch\_dims:]. * ​target (`StringSlice`): The target architecture to execute on. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​data ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank data\_rank >= 1. * ​indices ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank indices\_rank >= 1. All index values are expected to be within bounds \[-s, s-1] along axis of size s. It is an error if any of the index values are out of bounds. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank data\_rank + indices\_rank - indices\_shape\[-1] - 1 - batch\_dims. 
* ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The DeviceContextPtr as prepared by the graph compiler.
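For `batch_dims = 0`, each row of `indices` is a (possibly partial) multi-index into `data`; a partial index selects a whole sub-tensor, which is what makes the output rank `data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims`. A Python sketch over nested lists:

```python
def gather_nd(data, indices):
    # batch_dims == 0: walk each multi-index one coordinate at a time;
    # whatever remains un-indexed is returned as a sub-tensor.
    out = []
    for idx in indices:
        item = data
        for i in idx:
            item = item[i]
        out.append(item)
    return out

data = [[1, 2], [3, 4]]
print(gather_nd(data, [[0, 1], [1, 0]]))  # [2, 3]      (full indices -> scalars)
print(gather_nd(data, [[1]]))             # [[3, 4]]    (partial index -> a row)
```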
--- ## gather_nd_shape
`gather_nd_shape[output_rank: Int, input_type: DType, indices_type: DType, batch_dims: Int, single_thread_blocking_override: Bool = True](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices_buf: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[output_rank]` Compute the output shape of a `gather` operation, and assert the inputs are compatible. **Parameters:** * ​output\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Rank of the output tensor. * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the indices tensor. * ​batch\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Batch dimensions. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then reduction is run synchronously using a single thread. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​indices\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The indices tensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## gather_reduce
`gather_reduce[dtype: DType, gather_axis: Int, reduce_axis: Int, simd_width: Int, reduce_fn: fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) -> SIMD[dtype, width]](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[DType.int32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], reduce_init: Scalar[dtype])` Computes output\[i, j, k] = input\[indices\[i, j], k] and simultaneously reduces the output across axis 1 to produce output\[i, k]. The motivating use-case for this is multi-hot embeddings in recommender models. This provides similar functionality to Torch's EmbeddingBag layer. In that context, i is the batch dimension, j is the multi-hot dimension, and k is the embedding dimension.
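With a sum reduction this matches EmbeddingBag-style pooling: gather rows `input[indices[i, j]]` and reduce over j, producing one pooled vector per batch element i. A minimal Python sketch over nested lists (sum reduction assumed; the kernel is generic over `reduce_fn`):

```python
def gather_reduce_sum(table, indices):
    # table: [num_rows][k] embedding table; indices: [batch][multi_hot].
    out = []
    for bag in indices:
        acc = [0.0] * len(table[0])
        for j in bag:
            acc = [a + t for a, t in zip(acc, table[j])]  # reduce over axis 1
        out.append(acc)
    return out

table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
print(gather_reduce_sum(table, [[0, 2], [1, 1]]))  # [[4.0, 4.0], [4.0, 4.0]]
```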
--- ## gather_shape
`gather_shape[output_rank: Int, input_type: DType, indices_type: DType, single_thread_blocking_override: Bool = False](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices_buf: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int) -> IndexList[output_rank]` Compute the output shape of a `gather` operation, and assert the inputs are compatible. **Parameters:** * ​output\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Rank of the output tensor. * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the indices tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​indices\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The indices tensor. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## gather_scatter
## `comptime` values ### `error_index_fn_type` `comptime error_index_fn_type = fn(Int) capturing -> None` ## Structs * [​`Axis`](./Axis): ## Functions * [​`gather`](./gather): Gather operation as defined in . * [​`gather_elements`](./gather_elements): Implements ONNX GatherElements op which is equivalent to PyTorch gather. * [​`gather_elementwise_fn_wrapper`](./gather_elementwise_fn_wrapper): * [​`gather_guards`](./gather_guards): * [​`gather_nd`](./gather_nd): GatherND operation as defined in . Based on reference implementation: . * [​`gather_nd_shape`](./gather_nd_shape): Compute the output shape of a `gather` operation, and assert the inputs are compatible. * [​`gather_reduce`](./gather_reduce): Computes output\[i, j, k] = input\[indices\[i, j], k] and simultaneously reduces the output across axis 1 to produce output\[i, k]. * [​`gather_shape`](./gather_shape): Compute the output shape of a `gather` operation, and assert the inputs are compatible. * [​`normalize_neg_index`](./normalize_neg_index): Indices passed to gather and scatter ops may be negative. This performs a normalization so that they can be used to index into a buffer. * [​`scatter_elements`](./scatter_elements): Implements ONNX ScatterElements op which is equivalent to PyTorch scatter. * [​`scatter_elements_shape`](./scatter_elements_shape): Compute the output shape of a `scatter_elements` operation, and assert the inputs are compatible. * [​`scatter_nd`](./scatter_nd): Scatter\_nd operation without any reduction. * [​`scatter_nd_generator`](./scatter_nd_generator): Implements ONNX ScatterND operation as defined in . * [​`scatter_nd_shape`](./scatter_nd_shape): Compute the output shape of a `scatter_nd` operation, and assert the inputs are compatible. * [​`scatter_set_constant`](./scatter_set_constant): Scatter the fill\_value into the data at the specified indices.
--- ## normalize_neg_index
`normalize_neg_index(idx: Int, dim_size: Int) -> Int` Indices passed to gather and scatter ops may be negative. This performs a normalization so that they can be used to index into a buffer. Returns `idx + dim_size` if `idx < 0`, else `idx`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) `normalize_neg_index[dtype: DType, width: Int, out_type: DType = DType.index](idx: SIMD[dtype, width], dim_size: Int) -> SIMD[out_type, width]` Indices passed to gather and scatter ops may be negative. This performs a normalization so that they can be used to index into a buffer. Returns `idx + dim_size` if `idx < 0`, else `idx`. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
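In Python terms (illustrative only):

```python
def normalize_neg_index(idx, dim_size):
    # A negative index counts back from the end of the dimension,
    # so -1 maps to dim_size - 1, -2 to dim_size - 2, and so on.
    return idx + dim_size if idx < 0 else idx

a = normalize_neg_index(-1, 8)   # last element of a size-8 axis
b = normalize_neg_index(3, 8)    # non-negative indices pass through unchanged
```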
--- ## scatter_elements
`scatter_elements[reduce_fn: fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], rank: Int, input_type: DType, indices_type: DType](input: ManagedTensorSlice[io_spec, static_spec=static_spec], indices: ManagedTensorSlice[io_spec, static_spec=static_spec], updates: ManagedTensorSlice[io_spec, static_spec=static_spec], _axis: Int, output: ManagedTensorSlice[io_spec, static_spec=static_spec])` Implements ONNX ScatterElements op which is equivalent to PyTorch scatter.
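A NumPy sketch of the ScatterElements semantics for the 2-D, `axis=0` case (an illustrative analogue of the op, not this kernel): each value of `updates` is written into a copy of the input at the row given by `indices` and its own column.

```python
import numpy as np

def scatter_elements_axis0(data, indices, updates):
    """Sketch of ONNX ScatterElements for 2-D input with axis=0:
    out[indices[i, j], j] = updates[i, j]."""
    out = data.copy()
    rows, cols = indices.shape
    for i in range(rows):
        for j in range(cols):
            out[indices[i, j], j] = updates[i, j]
    return out

data = np.zeros((3, 3))
indices = np.array([[1, 0, 2], [0, 2, 1]])
updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]])
out = scatter_elements_axis0(data, indices, updates)
```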
--- ## scatter_elements_shape
`scatter_elements_shape[input_type: DType, indices_type: DType, //, *, single_thread_blocking_override: Bool](input: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], updates: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `scatter_elements` operation, and assert the inputs are compatible. **Parameters:** * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the indices tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​updates ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The updates tensor. * ​indices ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The indices tensor. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## scatter_nd
`scatter_nd[output_type: DType, indices_type: DType, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu"](data: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], updates: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())` Scatter\_nd operation without any reduction.
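The ScatterND-style semantics (no reduction) can be sketched in NumPy — an illustrative analogue, not this kernel: each row of `indices` is a coordinate prefix into `data`, and the slice it addresses is overwritten by the corresponding slice of `updates`.

```python
import numpy as np

def scatter_nd(data, indices, updates):
    # indices has shape [..., k]; each length-k row addresses a slice of data.
    out = data.copy()
    for pos in np.ndindex(*indices.shape[:-1]):
        out[tuple(indices[pos])] = updates[pos]
    return out

data = np.arange(8.0).reshape(4, 2)            # [[0,1],[2,3],[4,5],[6,7]]
indices = np.array([[0], [2]])                 # two row coordinates (k = 1)
updates = np.array([[9.0, 9.0], [7.0, 7.0]])   # one row of data per coordinate
out = scatter_nd(data, indices, updates)
```

The last write wins when coordinates repeat, matching the no-reduction behavior documented here.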
--- ## scatter_nd_generator
`scatter_nd_generator[output_type: DType, indices_type: DType, single_thread_blocking_override: Bool, target: StringSlice[StaticConstantOrigin] = "cpu", /, reduce_fn: OptionalReg[fn[dtype: DType, width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]] = None, *, _trace_description: StringSlice[StaticConstantOrigin] = "scatter_nd"](data: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], updates: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())` Implements ONNX ScatterND operation as defined in . **Parameters:** * ​output\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of data, updates, and output tensors. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the indices tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​target (`StringSlice`): Target cpu or cuda. * ​reduce\_fn ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Reduction function to apply: none (default), add, mul, max, min. * ​\_trace\_description (`StringSlice`): A description of the function, used for profiling and tracing. **Args:** * ​data ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank data\_rank >= 1. 
* ​indices ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank indices\_rank containing indices for the scatter operation. * ​updates ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor containing values to update output tensor based on indices tensor. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank data\_rank, shaped the same as data tensor. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): Pointer to DeviceContext.
--- ## scatter_nd_shape
`scatter_nd_shape[input_type: DType, indices_type: DType, single_thread_blocking_override: Bool](input: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], updates: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `scatter_nd` operation, and assert the inputs are compatible. **Parameters:** * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the indices tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​updates ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The updates tensor. * ​indices ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The indices tensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## scatter_set_constant
`scatter_set_constant[data_type: DType, index_type: DType, //, target: StringSlice[StaticConstantOrigin], single_thread_blocking_override: Bool = False](data: LayoutTensor[data_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[index_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], fill_value: Scalar[data_type], ctx: DeviceContextPtr)` Scatter the fill\_value into the data at the specified indices. Example: Suppose we have a 3x3 matrix `data` initialized to zeros: data = [\[0, 0, 0], \[0, 0, 0], \[0, 0, 0]] And `indices` is a 2D tensor with shape \[2, 2]: indices = [\[0, 1], \[2, 0]] If `fill_value` is 5, after calling `scatter_set_constant`, `data` will be: data = [\[0, 5, 0], \[0, 0, 0], \[5, 0, 0]] **Args:** * ​data ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The data to scatter the updates into. * ​indices ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The indices to scatter the updates into. * ​fill\_value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to fill the data with. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The device context.
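The worked example above, reproduced as a NumPy sketch (an illustrative analogue; the helper name is hypothetical):

```python
import numpy as np

def scatter_set_constant(data, indices, fill_value):
    # Write fill_value at each (row, col) coordinate listed in `indices`.
    out = data.copy()
    for coord in indices:
        out[tuple(coord)] = fill_value
    return out

data = np.zeros((3, 3), dtype=int)
indices = np.array([[0, 1], [2, 0]])   # shape [2, 2]: two coordinates
out = scatter_set_constant(data, indices, 5)
```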
--- ## Image2DLayout
`@register_passable(trivial)` `struct Image2DLayout` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `FRSCf` `comptime FRSCf = Image2DLayout(3)` ### `NCHW` `comptime NCHW = Image2DLayout(1)` ### `NHWC` `comptime NHWC = Image2DLayout(0)` ### `RSCF` `comptime RSCF = Image2DLayout(2)` ### `UNKNOWN` `comptime UNKNOWN = Image2DLayout(-1)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## ImageData
`@register_passable(trivial)` `struct ImageData[layout: Layout, dtype: DType, static_image_layout: Image2DLayout, origin: MutOrigin]` Utility class that generalizes conv2d data and filter tensor with a given data layout. ## Fields * ​data (`LayoutTensor[dtype, layout, origin]`): * ​dynamic\_image\_layout (`Image2DLayout`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(data: LayoutTensor[dtype, layout, origin], _layout: Image2DLayout) -> Self` Constructs an image data instance with a dynamic layout param. **Args:** * ​data ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): A 4d buffer containing the actual data. * ​\_layout ([`Image2DLayout`](/mojo/kernels/nn/image/Image2DLayout)): Data layout tag. `__init__(data: LayoutTensor[dtype, layout, origin]) -> Self` ### `__getitem__` `__getitem__(self, n: Int, c: Int, h: Int, w: Int) -> Scalar[dtype]` Reads the underlying data buffer based on the tensor index and underlying data layout. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the batch dimension. * ​c ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the channel dimension. * ​h ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the height dimension. * ​w ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the width dimension. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The value stored at the given index position.
### `__setitem__` `__setitem__(self, n: Int, c: Int, h: Int, w: Int, value: Scalar[dtype])` Writes the underlying data buffer based on the tensor index and underlying data layout. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the batch dimension. * ​c ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the channel dimension. * ​h ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the height dimension. * ​w ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the width dimension. * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to store at the given index position. ### `to_static_layout` `to_static_layout[new_static_image_layout: Image2DLayout](self) -> ImageData[layout, dtype, new_static_image_layout, origin]` Conversion utility from a fully dynamic data structure, e.g. from a C shim, to one with compile-time known data layout. **Returns:** `ImageData`: The image data with static data layout. ### `get_image_layout` `get_image_layout(self) -> Image2DLayout` The getter function of the underlying data layout, resolving from either statically or dynamically provided information. **Returns:** [`Image2DLayout`](/mojo/kernels/nn/image/Image2DLayout): The resolved data layout tag for this image instance. ### `get_flat_index` `get_flat_index(self, n: Int, c: Int, h: Int, w: Int) -> Int` Converts the dimension index to the flat index of the underlying data based on the tensor layout. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the batch dimension. * ​c ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the channel dimension. * ​h ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the height dimension. * ​w ([`Int`](/mojo/stdlib/builtin/int/Int)): Index on the width dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer containing the index based on the underlying data layout.
### `get_tuple_index` `get_tuple_index(self, idx: Int) -> IndexList[4]` Converts the flat index to the dimension index of the underlying data based on the tensor layout. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Flat index. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): An IndexList containing the index in NCHW order. ### `num_elements` `num_elements(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
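The kind of layout-dependent flat-index mapping that `get_flat_index` performs can be sketched as follows. This is a hypothetical Python analogue assuming the usual row-major stride conventions for NCHW and NHWC storage, not the struct's actual implementation:

```python
def flat_index(n, c, h, w, shape_nchw, layout):
    # Map a logical (n, c, h, w) index to a flat offset, assuming
    # row-major storage in the given layout order.
    N, C, H, W = shape_nchw
    if layout == "NCHW":
        return ((n * C + c) * H + h) * W + w
    elif layout == "NHWC":
        return ((n * H + h) * W + w) * C + c
    raise ValueError("unsupported layout")

# Same logical index, different physical offsets per layout.
i_nchw = flat_index(0, 1, 2, 3, (1, 4, 5, 6), "NCHW")
i_nhwc = flat_index(0, 1, 2, 3, (1, 4, 5, 6), "NHWC")
```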
--- ## ImageShape
`@register_passable(trivial)` `struct ImageShape` A data-layout agnostic representation of tensor shapes used in conv2d. ## Fields * ​N (`Int`): * ​C (`Int`): * ​H (`Int`): * ​W (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__[layout: Layout, dtype: DType, image_layout: Image2DLayout](image_data: ImageData[layout, dtype, image_layout, origin]) -> Self` Constructor of an ImageShape instance from an ImageData. **Args:** * ​image\_data ([`ImageData`](/mojo/kernels/nn/image/ImageData)): The image data instance to extract shape info from.
--- ## PadHandling
`@register_passable(trivial)` `struct PadHandling` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `EXCLUDE_PAD` `comptime EXCLUDE_PAD = PadHandling(0)` ### `INCLUDE_PAD` `comptime INCLUDE_PAD = PadHandling(2)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## image
## Structs * [​`Image2DLayout`](./Image2DLayout): * [​`ImageData`](./ImageData): Utility class that generalizes conv2d data and filter tensor with a given data layout. * [​`ImageShape`](./ImageShape): A data-layout agnostic representation of tensor shapes used in conv2d. * [​`PadHandling`](./PadHandling):
--- ## nn
Provides neural network operators for deep learning models. ## Packages * [​`attention`](./attention/): Attention operations. ## Modules * [​`activations`](./activations/): The module contains implementations of activation functions. * [​`arange`](./arange/): * [​`arg_nonzero`](./arg_nonzero/): * [​`argmaxmin`](./argmaxmin/): * [​`argmaxmin_gpu`](./argmaxmin_gpu/): * [​`argsort`](./argsort/): * [​`bicubic`](./bicubic/): This module provides CPU and GPU implementations for bicubic interpolation. * [​`broadcast`](./broadcast/): * [​`concat`](./concat/): * [​`conv`](./conv/): * [​`conv_transpose`](./conv_transpose/): * [​`conv_utils`](./conv_utils/): * [​`cumsum`](./cumsum/): * [​`flash_attention`](./flash_attention/): * [​`fold`](./fold/): Implements the fold operation. * [​`fused_qk_rope`](./fused_qk_rope/): * [​`gather_scatter`](./gather_scatter/): * [​`image`](./image/): * [​`index_tensor`](./index_tensor/): * [​`irfft`](./irfft/): Inverse real FFT kernel using cuFFT. * [​`kv_cache`](./kv_cache/): * [​`kv_cache_ragged`](./kv_cache_ragged/): * [​`mha`](./mha/): * [​`mha_cross`](./mha_cross/): * [​`mha_fa3_utils`](./mha_fa3_utils/): * [​`mha_mask`](./mha_mask/): * [​`mha_operand`](./mha_operand/): * [​`mha_score_mod`](./mha_score_mod/): * [​`mha_sm100_1q`](./mha_sm100_1q/): * [​`mha_sm100_2q`](./mha_sm100_2q/): * [​`mha_sm90`](./mha_sm90/): * [​`mha_tile_scheduler`](./mha_tile_scheduler/): * [​`mha_utils`](./mha_utils/): * [​`mla`](./mla/): * [​`mla_graph`](./mla_graph/): * [​`mla_prefill_sm100`](./mla_prefill_sm100/): * [​`moe`](./moe/): * [​`nms`](./nms/): * [​`normalization`](./normalization/): * [​`pad`](./pad/): * [​`pad_gpu`](./pad_gpu/): * [​`pool`](./pool/): * [​`rand_normal`](./rand_normal/): * [​`rand_uniform`](./rand_uniform/): * [​`randn`](./randn/): * [​`repeat_interleave`](./repeat_interleave/): * [​`reshape`](./reshape/): * [​`resize`](./resize/): * [​`roi_align`](./roi_align/): * [​`rope`](./rope/): * [​`sampling`](./sampling/): * [​`shapes`](./shapes/): * [​`slice`](./slice/): * [​`softmax`](./softmax/): * [​`spatial_merge`](./spatial_merge/): * [​`split`](./split/): * [​`tile`](./tile/): * [​`topk`](./topk/): * [​`topk_fi`](./topk_fi/): * [​`toppminp`](./toppminp/): * [​`toppminp_gpu`](./toppminp_gpu/):
--- ## advanced_indexing_getitem
`advanced_indexing_getitem[input_rank: Int, index_rank: Int, input_type: DType, index_type: DType, //, start_axis: Int, num_index_tensors: Int, target: StringSlice[StaticConstantOrigin], single_thread_blocking_override: Bool, trace_description: StringSlice[StaticConstantOrigin], input_tensor_fn: fn[width: Int](IndexList[input_rank]) capturing -> SIMD[input_type, width], indices_fn: fn[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]](out_tensor: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], in_tensor_strides: IndexList[input_rank], ctx: DeviceContextPtr)` Implement basic numpy-style advanced indexing. This is designed to be fused with other view-producing operations to implement full numpy-indexing semantics. This assumes the dimensions in `input_tensor` not indexed by index tensors are ":", i.e., selecting all indices along the slice. For example in numpy: ``` # rank(indices1) == 3 # rank(indices2) == 3 out_tensor = input_tensor[:, :, :, indices1, indices2, :, :] ``` We calculate the following for all valid values of the indexing variables: ``` out_tensor[a, b, c, i, j, k, d, e] = input_tensor[ a, b, c, indices1[i, j, k], indices2[i, j, k], d, e ] ``` In this example `start_axis = 3` and `num_index_tensors = 2`. TODO(GEX-1951): Support boolean tensor masks TODO(GEX-1952): Support non-contiguous indexing tensor case TODO(GEX-1953): Support fusion (especially view-fusion) **Parameters:** * ​input\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the input tensor. * ​index\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the indexing tensors. * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input tensor. * ​index\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the indexing tensors.
* ​start\_axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The first dimension in input where the indexing tensors are applied. It is assumed the indexing tensors are applied in consecutive dimensions. * ​num\_index\_tensors ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of indexing tensors. * ​target (`StringSlice`): The target architecture to operate on. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​trace\_description (`StringSlice`): For profiling, the trace name the operation will appear under. * ​input\_tensor\_fn (`fn[width: Int](IndexList[input_rank]) capturing -> SIMD[input_type, width]`): Fusion lambda for the input tensor. * ​indices\_fn (`fn[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]`): Fusion lambda for the indices tensors. **Args:** * ​out\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor to write to. * ​in\_tensor\_strides ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The strides of the input tensor. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The DeviceContextPtr as prepared by the graph compiler.
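The indexing pattern in the docstring corresponds directly to NumPy advanced indexing. A smaller concrete instance (here `start_axis = 1`, `num_index_tensors = 2`, with rank-2 index tensors on a rank-4 input):

```python
import numpy as np

# Index tensors applied at consecutive axes 1 and 2; axes 0 and 3 are ":".
x = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
idx1 = np.array([[0, 1], [2, 0]])   # values index axis 1 (size 3)
idx2 = np.array([[3, 0], [1, 2]])   # values index axis 2 (size 4)

out = x[:, idx1, idx2, :]           # shape (2, 2, 2, 5)
# For all valid a, i, j, d:
#   out[a, i, j, d] == x[a, idx1[i, j], idx2[i, j], d]
```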
--- ## advanced_indexing_getitem_shape
`advanced_indexing_getitem_shape[input_rank: Int, index_rank: Int, //, start_axis: Int, num_index_tensors: Int](input_shape: IndexList[input_rank], index_shape: IndexList[index_rank]) -> IndexList[((input_rank + index_rank) - num_index_tensors)]` Calculate the output shape from advanced indexing. **Parameters:** * ​input\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the input tensor. * ​index\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the indexing tensors. * ​start\_axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The first dimension in input where the indexing tensors are applied. It is assumed the indexing tensors are applied in consecutive dimensions. * ​num\_index\_tensors ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of indexing tensors. **Args:** * ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the input tensor in the operation. * ​index\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the indexing tensors in the operation. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
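The shape rule stated by the signature (output rank `input_rank + index_rank - num_index_tensors`) can be sketched as follows — a hypothetical Python helper, not this API: the block of indexed axes is replaced by the shape of the index tensors.

```python
def advanced_getitem_shape(input_shape, index_shape, start_axis, num_index_tensors):
    # Axes [start_axis, start_axis + num_index_tensors) are replaced by
    # the shape of the index tensors; surrounding axes pass through.
    return (input_shape[:start_axis]
            + index_shape
            + input_shape[start_axis + num_index_tensors:])

shape = advanced_getitem_shape((2, 3, 4, 5), (2, 2),
                               start_axis=1, num_index_tensors=2)
```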
--- ## advanced_indexing_setitem_inplace
`advanced_indexing_setitem_inplace[index_rank: Int, updates_rank: Int, input_type: DType, index_type: DType, //, start_axis: Int, num_index_tensors: Int, target: StringSlice[StaticConstantOrigin], single_thread_blocking_override: Bool, trace_description: StringSlice[StaticConstantOrigin], updates_tensor_fn: fn[width: Int](IndexList[updates_rank]) capturing -> SIMD[input_type, width], indices_fn: fn[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]](input_tensor: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], index_tensor_shape: IndexList[index_rank], updates_tensor_strides: IndexList[updates_rank], ctx: DeviceContextPtr)` Implement basic numpy-style advanced indexing with assignment. This is designed to be fused with other view-producing operations to implement full numpy-indexing semantics. This assumes the dimensions in `input_tensor` not indexed by index tensors are ":", i.e., selecting all indices along the slice. For example in numpy: ``` # rank(indices1) == 2 # rank(indices2) == 2 # rank(updates) == 2 input_tensor[:, :, :, indices1, indices2, :, :] = updates ``` We calculate the following for all valid values of the indexing variables: ``` input_tensor[ a, b, c, indices1[i, j], indices2[i, j], d, e ] = updates[i, j] ``` In this example `start_axis = 3` and `num_index_tensors = 2`. In terms of implementation details, our strategy is to iterate over all indices over a common iteration range. The idea is we can map indices in this range to the write location in `input_tensor` as well as the data location in `updates`. An example illustrates how this is possible: Imagine the `input_tensor` shape is \[A, B, C, D] and we have indexing tensors I1 and I2 with shape \[M, N, K]. Assume I1 and I2 are applied to dimensions 1 and 2.
An appropriate common iteration range is then (A, M, N, K, D). Note we expect `updates` to be the shape \[A, M, N, K, D]. We will show this by providing the mappings into `updates` and `input_tensor`: Consider an arbitrary set of indices in this range (a, m, n, k, d): \- The index into `updates` is (a, m, n, k, d). \- The index into `input_tensor` is (a, I1\[m, n, k], I2\[m, n, k], d). TODO(GEX-1951): Support boolean tensor masks TODO(GEX-1952): Support non-contiguous indexing tensor case TODO(GEX-1953): Support fusion (especially view-fusion) TODO(GEX-1954): Unify getitem and setitem using generic views. (Requires non-strided view functions). **Parameters:** * ​index\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the indexing tensors. * ​updates\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the updates tensor. * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input tensor. * ​index\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the indexing tensors. * ​start\_axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The first dimension in input where the indexing tensors are applied. It is assumed the indexing tensors are applied in consecutive dimensions. * ​num\_index\_tensors ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of indexing tensors. * ​target (`StringSlice`): The target architecture to operate on. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​trace\_description (`StringSlice`): For profiling, the trace name the operation will appear under. * ​updates\_tensor\_fn (`fn[width: Int](IndexList[updates_rank]) capturing -> SIMD[input_type, width]`): Fusion lambda for the update tensor. * ​indices\_fn (`fn[indices_index: Int](IndexList[index_rank]) capturing -> Scalar[index_type]`): Fusion lambda for the indices tensors.
**Args:** * ​input\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor being indexed into and modified in-place. * ​index\_tensor\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of each index tensor. * ​updates\_tensor\_strides ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The strides of the update tensor. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The DeviceContextPtr as prepared by the graph compiler.
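The common-iteration-range argument in the docstring can be sketched in NumPy — an illustrative analogue of the assignment semantics, not this kernel. Here the input has shape [A, B, C, D] = [2, 4, 4, 3], index tensors I1, I2 of shape [M, N, K] = [2, 2, 2] are applied at axes 1 and 2, and `updates` has shape [A, M, N, K, D]:

```python
import numpy as np

x = np.zeros((2, 4, 4, 3))
i1 = np.arange(8).reshape(2, 2, 2) % 4          # index values in [0, 4)
i2 = (np.arange(8).reshape(2, 2, 2) + 1) % 4
updates = np.full((2, 2, 2, 2, 3), 5.0)         # shape [A, M, N, K, D]

# Iterate the common range (A, M, N, K); the trailing D axis is a full slice.
for a in range(2):
    for m in range(2):
        for n in range(2):
            for k in range(2):
                # Write location in x: (a, I1[m,n,k], I2[m,n,k], :)
                # Read location in updates: (a, m, n, k, :)
                x[a, i1[m, n, k], i2[m, n, k]] = updates[a, m, n, k]
```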
--- ## index_tensor
## Functions * [​`advanced_indexing_getitem`](./advanced_indexing_getitem): Implement basic numpy-style advanced indexing. * [​`advanced_indexing_getitem_shape`](./advanced_indexing_getitem_shape): Calculate the output shape from advanced indexing. * [​`advanced_indexing_setitem_inplace`](./advanced_indexing_setitem_inplace): Implement basic numpy-style advanced indexing with assignment. * [​`index_tensor`](./index_tensor): Index\_tensor operation, based on a modified implementation of gather\_nd. * [​`index_tensor_shape`](./index_tensor_shape): Compute the output shape of an `index_tensor` operation, and assert the inputs are compatible.
--- ## index_tensor (Index_tensor)
`index_tensor[dtype: DType, indices_type: DType, batch_dims: Int, target: StringSlice[StaticConstantOrigin] = "cpu", single_thread_blocking_override: Bool = False](data: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Index\_tensor operation; based on modified implementation of gather\_nd. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of data tensor. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of indices tensor. * ​batch\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of batch dimensions. The gather of indexing starts from dimension of data\[batch\_dims:]. * ​target (`StringSlice`): The target architecture to execute on. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​data ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank data\_rank >= 1. * ​indices ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank indices\_rank >= 1. All index values are expected to be within bounds \[-s, s-1] along axis of size s. It is an error if any of the index values are out of bounds. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of rank data\_rank + indices\_rank - indices\_shape\[-1] - 1 - b. 
* ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The DeviceContextPtr as prepared by the graph compiler.
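The gather\_nd-style semantics can be illustrated with a hypothetical pure-Python sketch (batch\_dims = 0; this is not the Mojo kernel): each row of `indices` selects a slice of `data`, and an index row shorter than the data rank returns a whole sub-tensor.

```python
def index_tensor(data, indices):
    """Gather slices of `data` selected by index rows (gather_nd style,
    batch_dims = 0). Negative indices in [-s, s-1] wrap around."""
    out = []
    for idx in indices:
        item = data
        for i in idx:
            item = item[i]  # descend one dimension per index component
        out.append(item)
    return out

data = [[1, 2], [3, 4], [5, 6]]                 # shape (3, 2)
print(index_tensor(data, [(0,), (2,)]))         # full rows -> [[1, 2], [5, 6]]
print(index_tensor(data, [(1, 0), (-1, 1)]))    # scalars   -> [3, 6]
```

Note how the out-of-bounds rule from the docstring shows up here: each index component must lie in \[-s, s-1] for its axis, which is exactly Python's list-indexing rule.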
--- ## index_tensor_shape
`index_tensor_shape[output_rank: Int, input_type: DType, indices_type: DType, batch_dims: Int, single_thread_blocking_override: Bool = True](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], indices_buf: LayoutTensor[indices_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[output_rank]` Compute the output shape of an `index_tensor` operation, and assert the inputs are compatible. **Parameters:** * ​output\_rank ([`Int`](/mojo/stdlib/builtin/int/Int)): Rank of the output tensor. * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​indices\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the indices tensor. * ​batch\_dims ([`Int`](/mojo/stdlib/builtin/int/Int)): Batch dimensions. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​indices\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The indices tensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
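The shape inference can be sketched in Python (an illustrative reimplementation of the gather\_nd-style rank formula above, not the Mojo function): the last dimension of `indices` says how many data dimensions each index row consumes after the batch dimensions, and the remaining data dimensions pass through.

```python
def index_tensor_shape(data_shape, indices_shape, batch_dims):
    # Each index row consumes indices_shape[-1] dimensions of `data`
    # after the batch dimensions; trailing dims pass through unchanged.
    k = indices_shape[-1]
    out = tuple(indices_shape[:-1]) + tuple(data_shape[batch_dims + k:])
    # Rank matches the index_tensor docstring:
    # data_rank + indices_rank - indices_shape[-1] - 1 - batch_dims.
    assert len(out) == len(data_shape) + len(indices_shape) - k - 1 - batch_dims
    return out

print(index_tensor_shape((3, 4, 5), (2, 2), 0))  # -> (2, 5)
```

Here each of the 2 index rows picks a (5,)-slice out of a (3, 4, 5) tensor, so the output has shape (2, 5).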
--- ## global_cache_insert
`global_cache_insert(key: String, value: LegacyUnsafePointer[NoneType])`
--- ## global_cache_lookup
`global_cache_lookup(key: String) -> LegacyOpaquePointer` **Returns:** `LegacyOpaquePointer`
--- ## irfft
Inverse real FFT kernel using cuFFT. ## Functions * [​`global_cache_insert`](./global_cache_insert): * [​`global_cache_lookup`](./global_cache_lookup): * [​`irfft`](./irfft): Compute the inverse real FFT of the input tensor.
--- ## irfft
`irfft[input_type: DType, output_type: DType, alignment: Int](input: LayoutTensor[input_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], n: Int, buffer_size_mb: Int, ctx: DeviceContext)` Compute the inverse real FFT of the input tensor. Currently, the transform is applied only along the last dimension. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Complex input tensor. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Real output tensor. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Output signal size (if <= 0, computed as 2\*(input.size(axis) - 1)). * ​buffer\_size\_mb ([`Int`](/mojo/stdlib/builtin/int/Int)): Estimated buffer size in MB. * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context.
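For intuition, the transform along one axis can be sketched as a naive O(n²) inverse real DFT in pure Python. The kernel itself uses cuFFT; this sketch only illustrates the default length rule n = 2\*(input\_size - 1) and the conjugate-symmetry reconstruction of the full spectrum from the half spectrum.

```python
import cmath

def irfft(half_spectrum, n=0):
    # Default output length, matching the kernel's rule for n <= 0.
    if n <= 0:
        n = 2 * (len(half_spectrum) - 1)
    out = []
    for t in range(n):
        acc = 0.0
        for k, x_k in enumerate(half_spectrum):
            term = x_k * cmath.exp(2j * cmath.pi * k * t / n)
            # Bins with 0 < k < n - k also appear as conjugates at n - k
            # in the full spectrum, so they contribute twice.
            acc += 2 * term.real if 0 < k < n - k else term.real
        out.append(acc / n)
    return out

# The half spectrum of the real signal [1, 2, 3, 4] is [10, -2+2j, -2];
# the inverse transform recovers the original signal.
print([round(v, 6) for v in irfft([10, complex(-2, 2), -2])])
```

A half spectrum of length m thus produces 2\*(m - 1) real samples by default, which is why the `n` argument is only needed for odd-length or truncated outputs.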
--- ## generic_flash_attention_kv_cache_padded
`generic_flash_attention_kv_cache_padded[collection_t: KVCollectionT, dtype: DType, //, *, target: StringSlice[StaticConstantOrigin], mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1, num_heads: Int = -1](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, layer_idx: UInt32, valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## generic_flash_attention_kv_cache_padded_materialized_mask
`generic_flash_attention_kv_cache_padded_materialized_mask[collection_t: KVCollectionT, dtype: DType, //, *, target: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1, num_heads: Int = -1](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, layer_idx: UInt32, mask: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## generic_fused_qk_rope_bshd_continuous_batch
`generic_fused_qk_rope_bshd_continuous_batch[dtype: DType, //, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin]](q_proj: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: ContinuousBatchingKVCacheCollection[dtype_, kv_params_], freqs_cis: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], layer_idx: UInt32, valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())` Performs a fused RoPE projection for Q and K projections. We have a manually fused QKV projection with mo.opaque dtypes in our Llama model. Due to a limitation in custom op definitions, we can't declare both a tensor and an opaque dtype as outputs from a custom kernel. This requires us to declare only Q\_proj as an output of the QKV projection. If we immediately follow the QKV projection kernel with a RoPE kernel applied to K, we'll get a race condition, because the graph compiler doesn't know about the dependency between these kernels in the graph definition. Here we fuse the RoPE kernels applied to Q\_proj and K\_proj, so the K\_proj RoPE is only executed after the QKV projection completes. **Args:** * ​q\_proj ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Query projection tensor of shape \[batch, seq\_len, n\_heads, head\_dim]. * ​kv\_collection ([`ContinuousBatchingKVCacheCollection`](/mojo/kernels/kv_cache/types/ContinuousBatchingKVCacheCollection)): The continuous batching KV cache collection. 
* ​freqs\_cis ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Frequency tensor for RoPE of shape \[max\_seq\_len, head\_dim]. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The layer index for accessing the correct cache. * ​valid\_lengths ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of shape \[batch] containing the valid length for each sequence. RoPE is only applied to positions within these lengths. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor for Q with RoPE applied, same shape as q\_proj. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): Device context pointer for execution.
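The per-head rotation that RoPE applies (with `interleaved=True`) pairs consecutive elements of a head vector and rotates each pair by an angle derived from freqs\_cis. A hypothetical pure-Python sketch of just that math, with the angles passed in directly (this is not the Mojo API):

```python
import math

def rope_interleaved(head, angles):
    # Rotate each consecutive (even, odd) pair of the head vector by the
    # corresponding angle; len(angles) == len(head) // 2.
    out = list(head)
    for i, theta in enumerate(angles):
        x, y = head[2 * i], head[2 * i + 1]
        c, s = math.cos(theta), math.sin(theta)
        out[2 * i] = x * c - y * s
        out[2 * i + 1] = x * s + y * c
    return out

# A quarter-turn maps the pair (1, 0) to (0, 1); rotations preserve norms.
print([round(v, 6) for v in rope_interleaved([1.0, 0.0], [math.pi / 2])])
```

In the kernel, the angle for pair i of a token at position p comes from freqs\_cis\[p], and rows at positions beyond a sequence's valid\_lengths entry are left untouched.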
--- ## generic_fused_qk_rope_bshd_paged
`generic_fused_qk_rope_bshd_paged[dtype: DType, //, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin]](q_proj: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], freqs_cis: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], layer_idx: UInt32, valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())` Performs a fused RoPE projection for Q and K with paged KV cache. This is the paged equivalent of generic\_fused\_qk\_rope\_bshd\_continuous\_batch. It applies RoPE to both Q (returned) and K (in paged cache) to ensure proper dependency ordering after fused\_qkv\_padded\_matmul. **Args:** * ​q\_proj ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Query projection tensor of shape \[batch, seq\_len, n\_heads, head\_dim]. * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The paged KV cache collection. * ​freqs\_cis ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Frequency tensor for RoPE of shape \[max\_seq\_len, head\_dim]. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The layer index for accessing the correct cache. * ​valid\_lengths ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of shape \[batch] containing the valid length for each sequence. 
RoPE is only applied to positions within these lengths. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor for Q with RoPE applied, same shape as q\_proj. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): Device context pointer for execution.
--- ## generic_fused_qkv_matmul_kv_cache_bshd_continuous_batch
`generic_fused_qkv_matmul_kv_cache_bshd_continuous_batch[dtype: DType, target: StringSlice[StaticConstantOrigin] = "cpu"](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: ContinuousBatchingKVCacheCollection[dtype_, kv_params_], layer_idx: UInt32, valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. Only positions within valid\_lengths are written to the KV cache. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size, seq\_len, num\_heads \* head\_size). * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​kv\_collection ([`ContinuousBatchingKVCacheCollection`](/mojo/kernels/kv_cache/types/ContinuousBatchingKVCacheCollection)): The historical KVCache for keys and values. The KVCache for this layer is retrieved via layer\_idx. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The index of the layer being executed. Used to retrieve the KVCache for the given layer from kv\_collection. * ​valid\_lengths ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of shape \[batch] containing the valid length for each sequence. 
K and V are only written to cache for positions within these lengths. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The pre-allocated output buffer for Q projections. K and V projections are written in-place to k\_cache and v\_cache. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
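The routing of the fused result can be sketched as follows. This is a hypothetical illustration that assumes a \[Q | K | V] column layout in each output row of the fused matmul; the kernel's actual memory layout is not documented here.

```python
def split_qkv_row(fused_row, num_heads, num_kv_heads, head_size):
    # Assumed layout of one fused output row: [Q | K | V], where Q spans
    # num_heads * head_size columns and K and V span
    # num_kv_heads * head_size columns each.
    q_end = num_heads * head_size
    k_end = q_end + num_kv_heads * head_size
    return fused_row[:q_end], fused_row[q_end:k_end], fused_row[k_end:]

q, k, v = split_qkv_row(list(range(8)), num_heads=2, num_kv_heads=1, head_size=2)
print(q, k, v)  # [0, 1, 2, 3] [4, 5] [6, 7]
```

In the kernel, the Q slice goes to the pre-allocated `output` buffer, while the K and V slices are written into the KV cache only for token positions below the sequence's valid\_lengths entry.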
--- ## generic_fused_qkv_matmul_kv_cache_bshd_paged
`generic_fused_qkv_matmul_kv_cache_bshd_paged[dtype: DType, target: StringSlice[StaticConstantOrigin] = "cpu"](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: UInt32, valid_lengths: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. Only positions within valid\_lengths are written to the KV cache. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size, seq\_len, num\_heads \* head\_size). * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The historical KVCache for keys and values. The KVCache for this layer is retrieved via layer\_idx. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The index of the layer being executed. Used to retrieve the KVCache for the given layer from kv\_collection. * ​valid\_lengths ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor of shape \[batch] containing the valid length for each sequence. 
K and V are only written to cache for positions within these lengths. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The pre-allocated output buffer for Q projections. K and V projections are written in-place to k\_cache and v\_cache. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
--- ## generic_get_continuous_cache
`generic_get_continuous_cache[dtype: DType, kv_params: KVCacheStaticParams](blocks: LayoutTensor[dtype, Layout.row_major[6](), origin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), origin], lookup_table: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), origin], max_lengths: LayoutTensor[DType.uint32, Layout.row_major[2](), origin]) -> ContinuousBatchingKVCacheCollection[dtype, kv_params]` **Returns:** [`ContinuousBatchingKVCacheCollection`](/mojo/kernels/kv_cache/types/ContinuousBatchingKVCacheCollection)
--- ## generic_get_paged_cache
`generic_get_paged_cache[dtype: DType](blocks: ManagedTensorSlice[MutableInput, static_spec=static_spec], cache_lengths: ManagedTensorSlice[Input, static_spec=static_spec], lookup_table: ManagedTensorSlice[Input, static_spec=static_spec], max_lengths: ManagedTensorSlice[Input, static_spec=static_spec], out result: PagedKVCacheCollection[dtype, KVCacheStaticParams(UInt(static_spec.shape.get[4]()), UInt(static_spec.shape.get[5]()), (static_spec.shape.get[1]() == 1)), static_spec.shape.get[3]()])` **Returns:** [`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection) `generic_get_paged_cache[dtype: DType, kv_params: KVCacheStaticParams, page_size: Int](blocks: LayoutTensor[dtype, Layout.row_major[6](), origin], cache_lengths: LayoutTensor[DType.uint32, Layout(IntTuple(-1)), origin], lookup_table: LayoutTensor[DType.uint32, Layout.row_major[2](), origin], max_lengths: LayoutTensor[DType.uint32, Layout.row_major[2](), origin], out result: PagedKVCacheCollection[dtype, kv_params, page_size])` **Returns:** [`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)
--- ## kv_cache
## `comptime` values ### `embed_fn_type` `comptime embed_fn_type = fn[dtype: DType, width: Int](IndexList[4], SIMD[dtype, width]) capturing -> SIMD[dtype, width]` ## Functions * [​`generic_flash_attention_kv_cache_padded`](./generic_flash_attention_kv_cache_padded): * [​`generic_flash_attention_kv_cache_padded_materialized_mask`](./generic_flash_attention_kv_cache_padded_materialized_mask): * [​`generic_fused_qk_rope_bshd_continuous_batch`](./generic_fused_qk_rope_bshd_continuous_batch): Performs a fused RoPE projection for Q and K projections. * [​`generic_fused_qk_rope_bshd_paged`](./generic_fused_qk_rope_bshd_paged): Performs a fused RoPE projection for Q and K with paged KV cache. * [​`generic_fused_qkv_matmul_kv_cache_bshd_continuous_batch`](./generic_fused_qkv_matmul_kv_cache_bshd_continuous_batch): Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. * [​`generic_fused_qkv_matmul_kv_cache_bshd_paged`](./generic_fused_qkv_matmul_kv_cache_bshd_paged): Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. * [​`generic_get_continuous_cache`](./generic_get_continuous_cache): * [​`generic_get_paged_cache`](./generic_get_paged_cache): * [​`print_kv_cache_cont_batch_generic_cpu`](./print_kv_cache_cont_batch_generic_cpu): * [​`print_kv_cache_cont_batch_generic_gpu`](./print_kv_cache_cont_batch_generic_gpu): * [​`print_kv_cache_paged_generic_cpu`](./print_kv_cache_paged_generic_cpu): * [​`print_kv_cache_paged_generic_gpu`](./print_kv_cache_paged_generic_gpu): * [​`rms_norm_kv_cache_ragged_continuous_batching`](./rms_norm_kv_cache_ragged_continuous_batching): Performs RMSNorm in place on new entries in the key cache. * [​`rms_norm_kv_cache_ragged_paged`](./rms_norm_kv_cache_ragged_paged): Performs RMSNorm in place on new entries in the key cache.
--- ## print_kv_cache_cont_batch_generic_cpu
`print_kv_cache_cont_batch_generic_cpu[target: StringSlice[StaticConstantOrigin], dtype: DType, kv_params: KVCacheStaticParams](valid_lengths: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: ContinuousBatchingKVCacheCollection[dtype, kv_params], layer_idx: UInt32, is_print_compact: Bool, context: DeviceContextPtr)`
--- ## print_kv_cache_cont_batch_generic_gpu
`print_kv_cache_cont_batch_generic_gpu[target: StringSlice[StaticConstantOrigin], dtype: DType, kv_params: KVCacheStaticParams](valid_lengths: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: ContinuousBatchingKVCacheCollection[dtype, kv_params], layer_idx: UInt32, is_print_compact: Bool, context: DeviceContextPtr)`
--- ## print_kv_cache_paged_generic_cpu
`print_kv_cache_paged_generic_cpu[target: StringSlice[StaticConstantOrigin], dtype: DType, kv_params: KVCacheStaticParams, page_size: Int](valid_lengths: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype, kv_params, page_size], layer_idx: UInt32, is_print_compact: Bool, context: DeviceContextPtr)`
--- ## print_kv_cache_paged_generic_gpu
`print_kv_cache_paged_generic_gpu[target: StringSlice[StaticConstantOrigin], dtype: DType, kv_params: KVCacheStaticParams, page_size: Int](valid_lengths: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype, kv_params, page_size], layer_idx: UInt32, is_print_compact: Bool, context: DeviceContextPtr)`
--- ## rms_norm_kv_cache_ragged_continuous_batching
`rms_norm_kv_cache_ragged_continuous_batching[dtype: DType, params: KVCacheStaticParams, //, target: StringSlice[StaticConstantOrigin], multiply_before_cast: Bool, per_head_norm: Bool](kv_collection: ContinuousBatchingKVCacheCollection[dtype, params], gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], layer_idx: UInt32, total_seq_len: UInt32, input_row_offsets: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)` Performs RMSNorm in place on new entries in the key cache. This is done by first creating a ragged tensor view with shape (total\_seq\_len, num\_heads, head\_dim) over the new tokens. To do this we need to pass in `total_seq_len` on the host. Then, using `input_row_offsets`, we find the corresponding batch and token index, and use that together with the static head and channel indices to store to/load from the key cache. This uses the input/output lambdas on the RMSNorm kernel. This function can apply RMSNorm to a subset of dimensions in each head, determined by the size of the gamma tensor. In that case, it operates on a ragged tensor view of the key cache with shape (total\_seq\_len, num\_heads, rms\_norm\_cols), where rms\_norm\_cols is the length of gamma and must be <= head\_size. `weight_offset` is a constant offset added to the learned weights at runtime. Here, we don't use any offset, so we pass in a zero scalar. `multiply_before_cast` is a boolean parameter that determines whether the normalized values are multiplied by the gamma tensor before casting to the output dtype. We set it to `True` by default.
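The two pieces of per-token work described above can be sketched in plain Python (hypothetical helpers, not the Mojo kernel): an RMSNorm over the first len(gamma) columns of a row, and the ragged batch lookup via input\_row\_offsets. The `multiply_before_cast` flag only affects when the cast to the output dtype happens, which has no analogue in this float-only sketch.

```python
import math
from bisect import bisect_right

def rms_norm_row(x, gamma, epsilon, weight_offset):
    # Normalize the first len(gamma) columns by their root mean square,
    # scale by (gamma + weight_offset); trailing columns pass through.
    cols = len(gamma)
    ms = sum(v * v for v in x[:cols]) / cols
    inv = 1.0 / math.sqrt(ms + epsilon)
    return [v * inv * (g + weight_offset) for v, g in zip(x, gamma)] + list(x[cols:])

def batch_of_token(input_row_offsets, token_idx):
    # input_row_offsets[i] is the first global row of sequence i, so the
    # owning batch is the last offset <= token_idx.
    return bisect_right(input_row_offsets, token_idx) - 1

print(batch_of_token([0, 3, 5], 4))  # token 4 belongs to sequence 1
```

With offsets \[0, 3, 5], sequence 0 owns global rows 0-2 and sequence 1 owns rows 3-4, so the token index within its sequence is token\_idx minus the owning offset.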
--- ## rms_norm_kv_cache_ragged_paged
`rms_norm_kv_cache_ragged_paged[dtype: DType, params: KVCacheStaticParams, page_size: Int, //, target: StringSlice[StaticConstantOrigin], multiply_before_cast: Bool, per_head_norm: Bool](kv_collection: PagedKVCacheCollection[dtype, params, page_size], gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], layer_idx: UInt32, total_seq_len: UInt32, input_row_offsets: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)` Performs RMSNorm in place on new entries in the key cache. This is done by first creating a ragged tensor view with shape (total\_seq\_len, num\_heads, head\_dim) over the new tokens. To do this we need to pass in `total_seq_len` on the host. Then, using `input_row_offsets`, we find the corresponding batch and token index, and use that together with the static head and channel indices to store to/load from the key cache. This uses the input/output lambdas on the RMSNorm kernel. This function can apply RMSNorm to a subset of dimensions in each head, determined by the size of the gamma tensor. In that case, it operates on a ragged tensor view of the key cache with shape (total\_seq\_len, num\_heads, rms\_norm\_cols), where rms\_norm\_cols is the length of gamma and must be <= head\_size. `weight_offset` is a constant offset added to the learned weights at runtime. Here, we don't use any offset, so we pass in a zero scalar. `multiply_before_cast` is a boolean parameter that determines whether the normalized values are multiplied by the gamma tensor before casting to the output dtype. We set it to `True` by default.
--- ## generic_cross_attention_kv_cache
`generic_cross_attention_kv_cache[collection_t: KVCollectionT, dtype: DType, //, target: StringSlice[StaticConstantOrigin], mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q_max_seq_len: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## generic_flare_mla_decode_kv_cache_ragged
`generic_flare_mla_decode_kv_cache_ragged[collection_t: KVCollectionT, dtype: DType, //, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)`
--- ## generic_flare_mla_decompress_k_cache_ragged_paged
`generic_flare_mla_decompress_k_cache_ragged_paged[target: StringSlice[StaticConstantOrigin], dtype: DType](buffer_row_offsets_1d: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_offsets_1d: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], buffer_length: Int32, weight: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: UInt32, k_latent_buffer: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_buffer: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)`
--- ## generic_flare_mla_prefill_kv_cache_ragged
`generic_flare_mla_prefill_kv_cache_ragged[collection_t: KVCollectionT, dtype: DType, //, softmax_type: DType, write_softmax_info: Bool, use_cascade_attention: Bool, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], buffer_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], softmax_info: LayoutTensor[softmax_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr, prev_output: OptionalReg[LayoutTensor[dtype, Layout.row_major[3](), MutAnyOrigin]] = None, 
prev_softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutAnyOrigin]] = None)`
--- ## generic_flare_mla_prefill_ragged_paged_plan
`generic_flare_mla_prefill_ragged_paged_plan[target: StringSlice[StaticConstantOrigin]](input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: UInt32, buffer_token_size: UInt32, buffer_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], buffer_lengths: LayoutTensor[DType.int32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)`
--- ## generic_flash_attention_kv_cache_ragged
`generic_flash_attention_kv_cache_ragged[collection_t: KVCollectionT, dtype: DType, //, *, target: StringSlice[StaticConstantOrigin], mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)`
--- ## generic_flash_attention_kv_cache_ragged_sink
`generic_flash_attention_kv_cache_ragged_sink[collection_t: KVCollectionT, dtype: DType, //, *, target: StringSlice[StaticConstantOrigin], mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], local_window_size: Int = -1](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr, sink_weights: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## generic_fused_qk_rope_bshd_continuous_batch_ragged
`generic_fused_qk_rope_bshd_continuous_batch_ragged[dtype: DType, freq_dtype: DType, //, *, interleaved: Bool, has_position_ids: Bool, target: StringSlice[StaticConstantOrigin], mrope_section: Optional[IntTuple] = None](q_proj: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: ContinuousBatchingKVCacheCollection[dtype_, kv_params_], freqs_cis: LayoutTensor[freq_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], position_ids: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], layer_idx: UInt32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)`
--- ## generic_fused_qk_rope_bshd_paged_ragged
`generic_fused_qk_rope_bshd_paged_ragged[dtype: DType, freq_dtype: DType, //, *, interleaved: Bool, has_position_ids: Bool, target: StringSlice[StaticConstantOrigin], mrope_section: Optional[IntTuple] = None](q_proj: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], freqs_cis: LayoutTensor[freq_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], position_ids: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], layer_idx: UInt32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())` Performs a fused RoPE projection for Q and K projections. We have a manually fused QKV projection with mo.opaque dtypes in our Llama model. Due to a limitation in custom op definitions, we can't declare both a tensor and an opaque dtype as outputs from a custom kernel, so only Q\_proj can be declared as an output of the QKV projection. If we immediately followed the QKV projection kernel with a RoPE kernel applied to K, we'd get a race condition, because the graph compiler doesn't know about the dependency between these kernels from the graph definition. Here we fuse the RoPE kernel applied to Q\_proj with the one applied to K\_proj, so the K\_proj RoPE executes only after the QKV projection completes.
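The `interleaved` rotation applied by this kernel is standard RoPE on (even, odd) element pairs. As an illustration only (plain Mojo scalars with a hypothetical angle, not the kernel's LayoutTensor API), one rotated pair looks like:

```mojo
from math import cos, sin

fn main():
    # Toy interleaved RoPE rotation of one (even, odd) element pair.
    # theta is a hypothetical angle; in the kernel it comes from freqs_cis
    # for the token's position and this dimension pair.
    var x0: Float64 = 1.0
    var x1: Float64 = 0.5
    var theta: Float64 = 0.25
    var r0 = x0 * cos(theta) - x1 * sin(theta)
    var r1 = x0 * sin(theta) + x1 * cos(theta)
    print("rotated pair:", r0, r1)
```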
--- ## generic_fused_qkv_matmul_kv_cache_cont_batch_ragged
`generic_fused_qkv_matmul_kv_cache_cont_batch_ragged[dtype: DType, //, target: StringSlice[StaticConstantOrigin] = "cpu"](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: ContinuousBatchingKVCacheCollection[dtype_, kv_params_], layer_idx: UInt32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,). The value at each index is the start\_idx of the corresponding batch in hidden\_state. * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​kv\_collection ([`ContinuousBatchingKVCacheCollection`](/mojo/kernels/kv_cache/types/ContinuousBatchingKVCacheCollection)): The object storing the KVCache for this layer. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The current layer, used to retrieve the KVCache object from kv\_collection. 
* ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The pre-allocated output buffer for Q projections. K and V projections are written in-place to k\_cache and v\_cache. Shape: (sum(seq\_lens), num\_heads \* head\_size). * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
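The ragged layout used by input\_row\_offsets can be sketched with plain Mojo (hypothetical lengths, not the kernel's LayoutTensor API): it is a prefix sum over sequence lengths, so consecutive entries bracket each sequence's rows in hidden\_state.

```mojo
from collections import List

fn main():
    # Hypothetical batch of 3 sequences with lengths 4, 2, and 5.
    # input_row_offsets has batch_size + 1 entries: a prefix sum of lengths.
    var input_row_offsets = List[Int](0, 4, 6, 11)
    for batch_idx in range(len(input_row_offsets) - 1):
        var start = input_row_offsets[batch_idx]
        var end = input_row_offsets[batch_idx + 1]
        print("sequence", batch_idx, "-> rows [", start, ",", end, ")")
```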
--- ## generic_fused_qkv_matmul_kv_cache_paged_ragged
`generic_fused_qkv_matmul_kv_cache_paged_ragged[dtype: DType, weight_dtype: DType, target: StringSlice[StaticConstantOrigin] = "cpu", group_size: OptionalReg[Int] = None, has_zp: OptionalReg[Bool] = None](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[weight_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: UInt32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,). The value at each index is the start\_idx of the corresponding batch in hidden\_state. * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The object storing the KVCache for this layer. 
* ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The current layer, used to retrieve the KVCache object from kv\_collection. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The pre-allocated output buffer for Q projections. K and V projections are written in-place to k\_cache and v\_cache. Shape: (sum(seq\_lens), num\_heads \* head\_size). * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
--- ## generic_fused_qkv_matmul_kv_cache_paged_ragged_bias
`generic_fused_qkv_matmul_kv_cache_paged_ragged_bias[dtype: DType, weight_dtype: DType, target: StringSlice[StaticConstantOrigin] = "cpu", group_size: OptionalReg[Int] = None, has_zp: OptionalReg[Bool] = None](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[weight_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: UInt32, output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], bias: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,). The value at each index is the start\_idx of the corresponding batch in hidden\_state. * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). 
* ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The object storing the KVCache for this layer. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The current layer, used to retrieve the KVCache object from kv\_collection. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The pre-allocated output buffer for Q projections. K and V projections are written in-place to k\_cache and v\_cache. Shape: (sum(seq\_lens), num\_heads \* head\_size). * ​bias ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Bias to be added to the QKV Tensor. Tensor is concatenated q + k + v. Rank 1. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
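The rank-1 bias above is laid out as the concatenation q + k + v. A minimal sketch of the slice offsets, assuming hypothetical head dimensions (not taken from any real model config):

```mojo
fn main():
    # Hypothetical dimensions: q spans num_heads * head_size elements,
    # k and v each span num_kv_heads * head_size elements.
    var q_dim = 8
    var k_dim = 4
    var v_dim = 4
    # Slice boundaries within the concatenated [q, k, v] bias vector:
    print("q bias rows: [ 0 ,", q_dim, ")")
    print("k bias rows: [", q_dim, ",", q_dim + k_dim, ")")
    print("v bias rows: [", q_dim + k_dim, ",", q_dim + k_dim + v_dim, ")")
```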
--- ## generic_fused_qkv_matmul_kv_cache_paged_ragged_scale
`generic_fused_qkv_matmul_kv_cache_paged_ragged_scale[dtype: DType, weight_dtype: DType, output_dtype: DType, scale_dtype: DType, target: StringSlice[StaticConstantOrigin] = "cpu"](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[weight_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_scale: LayoutTensor[scale_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight_scale: LayoutTensor[scale_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: UInt32, output: LayoutTensor[output_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr, bias: OptionalReg[LayoutTensor[output_dtype, Layout.row_major(-1), ImmutAnyOrigin]] = None)` Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,). 
The value at each index is the start\_idx of the corresponding batch in hidden\_state. * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​input\_scale ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale to be multiplied to the input Tensor. * ​weight\_scale ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale to be multiplied to the weight Tensor. * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The object storing the KVCache for this layer. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The current layer, used to retrieve the KVCache object from kv\_collection. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The pre-allocated output buffer for Q projections. K and V projections are written in-place to k\_cache and v\_cache. Shape: (sum(seq\_lens), num\_heads \* head\_size). * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler. * ​bias ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional bias vector concatenated as \[q, k, v].
--- ## generic_kv_cache_radd_dispatch
`generic_kv_cache_radd_dispatch[dtype: DType, collection_t: KVCollectionT, //, target: StringSlice[StaticConstantOrigin]](a: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache: collection_t, input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], batch_offset: UInt32, layer_idx: UInt32, ctx: Optional[DeviceContext])`
--- ## kv_cache_ragged
## Functions * [​`generic_cross_attention_kv_cache`](./generic_cross_attention_kv_cache): * [​`generic_flare_mla_decode_kv_cache_ragged`](./generic_flare_mla_decode_kv_cache_ragged): * [​`generic_flare_mla_decompress_k_cache_ragged_paged`](./generic_flare_mla_decompress_k_cache_ragged_paged): * [​`generic_flare_mla_prefill_kv_cache_ragged`](./generic_flare_mla_prefill_kv_cache_ragged): * [​`generic_flare_mla_prefill_ragged_paged_plan`](./generic_flare_mla_prefill_ragged_paged_plan): * [​`generic_flash_attention_kv_cache_ragged`](./generic_flash_attention_kv_cache_ragged): * [​`generic_flash_attention_kv_cache_ragged_sink`](./generic_flash_attention_kv_cache_ragged_sink): * [​`generic_fused_qk_rope_bshd_continuous_batch_ragged`](./generic_fused_qk_rope_bshd_continuous_batch_ragged): * [​`generic_fused_qk_rope_bshd_paged_ragged`](./generic_fused_qk_rope_bshd_paged_ragged): Performs a fused RoPE projection for Q and K projections. * [​`generic_fused_qkv_matmul_kv_cache_cont_batch_ragged`](./generic_fused_qkv_matmul_kv_cache_cont_batch_ragged): Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. * [​`generic_fused_qkv_matmul_kv_cache_paged_ragged`](./generic_fused_qkv_matmul_kv_cache_paged_ragged): Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. * [​`generic_fused_qkv_matmul_kv_cache_paged_ragged_bias`](./generic_fused_qkv_matmul_kv_cache_paged_ragged_bias): Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. * [​`generic_fused_qkv_matmul_kv_cache_paged_ragged_scale`](./generic_fused_qkv_matmul_kv_cache_paged_ragged_scale): Performs a fused QKV matmul. Q outputs are written to the output argument while K and V outputs are written in-place into k\_cache and v\_cache. 
* [​`generic_kv_cache_radd_dispatch`](./generic_kv_cache_radd_dispatch): * [​`k_matmul_ragged_paged`](./k_matmul_ragged_paged): Performs a matmul, writing the output into a mutable PagedKVCacheCollection object. * [​`k_matmul_ragged_paged_scale`](./k_matmul_ragged_paged_scale): Performs a matmul, writing the output into a mutable PagedKVCacheCollection object. * [​`kv_cache_2m_iadd_dispatch`](./kv_cache_2m_iadd_dispatch): In-place add to paged KV cache with concatenated K/V layout. This kernel is only used for LoRA. * [​`kv_cache_store_ragged`](./kv_cache_store_ragged): * [​`kv_matmul_ragged_paged`](./kv_matmul_ragged_paged): Performs a matmul, writing the output into a mutable PagedKVCacheCollection object. * [​`unfused_qkv_matmul_ragged_paged_gguf_quantized`](./unfused_qkv_matmul_ragged_paged_gguf_quantized): Performs a quantized matmul, writing the output into a mutable PagedKVCacheCollection object.
--- ## k_matmul_ragged_paged
`k_matmul_ragged_paged[dtype: DType, params: KVCacheStaticParams, page_size: Int, //, target: StringSlice[StaticConstantOrigin]](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype, params, page_size], layer_idx: UInt32, ctx: DeviceContextPtr)` Performs a matmul, writing the output into a mutable PagedKVCacheCollection object. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,) denoting the start of each sequence along the seq\_len dimension. * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The historical KVCache for keys and values. The KVCache for this layer is retrieved via layer\_idx. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The index of the layer being executed. Used to retrieve the KVCache for the given layer from kv\_collection. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
--- ## k_matmul_ragged_paged_scale
`k_matmul_ragged_paged_scale[dtype: DType, weight_dtype: DType, scale_dtype: DType, target: StringSlice[StaticConstantOrigin], scales_granularity_mnk: IndexList[3]](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[weight_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_scale: LayoutTensor[scale_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight_scale: LayoutTensor[scale_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype_, kv_params_, page_size], layer_idx: UInt32, ctx: DeviceContextPtr)` Performs a matmul, writing the output into a mutable PagedKVCacheCollection object. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,) denoting the start of each sequence along the seq\_len dimension. * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​input\_scale ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale to be multiplied to the input Tensor. 
* ​weight\_scale ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scale to be multiplied to the weight Tensor. * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The historical KVCache for keys and values. The KVCache for this layer is retrieved via layer\_idx. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The index of the layer being executed. Used to retrieve the KVCache for the given layer from kv\_collection. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
--- ## kv_cache_2m_iadd_dispatch
`kv_cache_2m_iadd_dispatch[dtype: DType, collection_t: KVCollectionT, //, target: StringSlice[StaticConstantOrigin]](kv: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache: collection_t, input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lora_end_idx: LayoutTensor[DType.int64, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], batch_seq_len: LayoutTensor[DType.int64, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], layer_idx: UInt32, ctx: Optional[DeviceContext])` In-place add to paged KV cache with concatenated K/V layout. This kernel is only used for LoRA. Performs an in-place addition of new key-value projections to the paged KV cache. The input tensor `kv` uses a "2m" layout where keys and values are concatenated: rows \[0, m) contain keys and rows \[m, 2m) contain values, where m is the number of tokens. We use `lora_end_idx` as this value `m` when indexing into the K and V halves, since the LoRA tokens may be a subset of the total tokens in the batch: tokens are written to K at rows \[0, m) and to V at rows \[m, 2m).
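The "2m" row mapping can be sketched in plain Mojo (a hypothetical token count, not the kernel's cache types):

```mojo
fn main():
    # Hypothetical "2m" concatenated K/V buffer with m = 3 tokens:
    # rows [0, m) hold keys, rows [m, 2m) hold values.
    var m = 3
    for token_idx in range(m):
        print("token", token_idx, "-> K row", token_idx, ", V row", m + token_idx)
```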
--- ## kv_cache_store_ragged
`kv_cache_store_ragged[cache_t: KVCacheT, input_row_offsets_layout: Layout, //, target: StringSlice[StaticConstantOrigin], input_fn: fn[width: Int, alignment: Int](idx: IndexList[3]) capturing -> SIMD[cache_t.dtype, width]](cache: cache_t, input_shape: IndexList[3], input_row_offsets: LayoutTensor[DType.uint32, input_row_offsets_layout, MutAnyOrigin], context: Optional[DeviceContext])`
--- ## kv_matmul_ragged_paged
`kv_matmul_ragged_paged[dtype: DType, params: KVCacheStaticParams, page_size: Int, //, target: StringSlice[StaticConstantOrigin]](hidden_state: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], weight: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype, params, page_size], layer_idx: UInt32, ctx: DeviceContextPtr)` Performs a matmul, writing the output into a mutable PagedKVCacheCollection object. **Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,) denoting the start of each sequence along the seq\_len dimension. * ​weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The historical KVCache for keys and values. The KVCache for this layer is retrieved via layer\_idx. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The index of the layer being executed. Used to retrieve the KVCache for the given layer from kv\_collection. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
--- ## unfused_qkv_matmul_ragged_paged_gguf_quantized
`unfused_qkv_matmul_ragged_paged_gguf_quantized[dtype: DType, params: KVCacheStaticParams, page_size: Int, //, quantization_encoding_q: StringSlice[StaticConstantOrigin], quantization_encoding_k: StringSlice[StaticConstantOrigin], quantization_encoding_v: StringSlice[StaticConstantOrigin]](hidden_state: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q_weight: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_weight: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v_weight: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: PagedKVCacheCollection[dtype, params, page_size], layer_idx: UInt32, output: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Performs a quantized matmul, writing the output into a mutable PagedKVCacheCollection object. Unlike the un-quantized version (kv\_matmul\_ragged\_continuous\_batching), this implementation does not concat the q, k, and v weights together. Instead, it performs three matmuls. This allows the q, k, and v weights to have different quantization encodings. This is only supported on CPU. 
**Args:** * ​hidden\_state ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_heads \* head\_size). * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (batch\_size + 1,) denoting the start of each sequence along the seq\_len dimension. * ​q\_weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​k\_weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​v\_weight ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (num\_heads \* head\_size, num\_kv\_heads \* head\_size). * ​kv\_collection ([`PagedKVCacheCollection`](/mojo/kernels/kv_cache/types/PagedKVCacheCollection)): The Collection object storing KVCache entries. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The index of the layer being executed. Used to retrieve the KVCache for the given layer from kv\_collection. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Tensor with shape (sum(seq\_lens), num\_kv\_heads \* head\_size). This is the output buffer for the Q matmul. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The call context pointer, passed by the graph compiler.
--- ## depth_supported_by_gpu
`depth_supported_by_gpu[depth: Int, mask_t: MHAMask, config: MHAConfig[dtype], info: GPUInfo]() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## flash_attention (Mha)
`flash_attention[dtype: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[2])), UInt(Int.__init__[IntTuple](q_layout.shape[3])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, context: DeviceContextPtr = DeviceContextPtr(), num_partitions: OptionalReg[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)` `flash_attention[cache_t: KVCacheT, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 
OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), ragged: Bool = False, sink: Bool = False, decoding_warp_split_k: Bool = False, naive_kernel: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: cache_t, v: cache_t, mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, num_partitions: OptionalReg[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)` Flash attention 2 algorithm. Computes: (1) Transpose (Q) BSHD -> BHSD; (2) Transpose (K) BSHD -> BHSD; (3) Transpose (V) BSHD -> BHSD; (4) P = Bmm(Q, K), P is also called "score"; (5) P = P \* scale + mask; (6) P = softmax(P); (7) O = Bmm(P, V); (8) Output = Transpose(O). B, S, H, D denote batch size, sequence length, head count, and depth, respectively. Steps (1), (2), and (3) happen while loading the data into shared memory; (8) happens when writing output to global memory. All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS. This kernel also handles the grouped-attention optimization, in which case the shapes of K and V are BShD where h = H / num\_groups. This kernel handles batches with different valid lengths (i.e., lengths before padding); such lengths are passed in the valid\_length argument.
`flash_attention[mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[2])), UInt(Int.__init__[IntTuple](q_layout.shape[3])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, _use_valid_length: Bool = False, _padded_ndbuffer: Bool = False, naive_kernel: Bool = False, sink: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask_functor: mask_t, score_mod_functor: score_mod_t, scale: Float32, ctx: DeviceContext, num_partitions: OptionalReg[Int] = None, valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)`
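The eight steps in the docstring above can be sketched in NumPy as a reference for the math the fused kernel computes. This is an illustrative sketch under stated assumptions, not the Mojo implementation: the function name is hypothetical, the mask is taken in BSS form (broadcast over heads), and none of the fusion or shared-memory staging is modeled.

```python
import numpy as np

def attention_reference(q, k, v, mask, scale):
    """Reference for the attention math the fused kernel computes.

    q, k, v: (B, S, H, D) arrays in BSHD layout.
    mask:    (B, S, S) additive mask, broadcast over heads.
    """
    # (1)-(3) Transpose BSHD -> BHSD.
    qt = q.transpose(0, 2, 1, 3)
    kt = k.transpose(0, 2, 1, 3)
    vt = v.transpose(0, 2, 1, 3)
    # (4)-(5) Score P = Bmm(Q, K), then scale and add the mask.
    p = qt @ kt.transpose(0, 1, 3, 2) * scale + mask[:, None]
    # (6) Row-wise softmax, computed stably.
    p = np.exp(p - p.max(axis=-1, keepdims=True))
    p /= p.sum(axis=-1, keepdims=True)
    # (7)-(8) O = Bmm(P, V), then transpose back BHSD -> BSHD.
    return (p @ vt).transpose(0, 2, 1, 3)
```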
--- ## flash_attention_dispatch
`flash_attention_dispatch[k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, kv_num_heads: Int, use_score_mod: Bool = False, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), ragged: Bool = False, sink: Bool = False, _is_flash_attention_applicable: Bool = True, _is_cache_length_accurate: Bool = False, _use_valid_length: Bool = True, _padded_ndbuffer: Bool = False, decoding_warp_split_k: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: k_t, v: v_t, mask_functor: mask_t, score_mod_functor: score_mod_t, max_prompt_len: Int, max_cache_valid_length: Int, scale: Float32, is_token_generation: Bool, ctx: DeviceContext, valid_length: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, num_partitions: OptionalReg[Int] = None, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## flash_attention_hw_supported
`flash_attention_hw_supported[qkv_type: DType]() -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## flash_attention_ragged
`flash_attention_ragged[mask_t: MHAMask, score_mod_t: ScoreModTrait, type: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig[type] = MHAConfig[type](UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False, naive_kernel: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[type, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin], max_prompt_len: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask_functor: mask_t, score_mod_functor: score_mod_t, scale: Float32, ctx: DeviceContext, num_partitions: OptionalReg[Int] = None)`
--- ## get_mha_decoding_num_partitions
`get_mha_decoding_num_partitions[num_heads: Int, group: Int](batch_size: Int, num_keys: Int, ctx: DeviceContext) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## mha
## Functions * [​`depth_supported_by_gpu`](./depth_supported_by_gpu): * [​`flash_attention`](./flash_attention): * [​`flash_attention_dispatch`](./flash_attention_dispatch): * [​`flash_attention_hw_supported`](./flash_attention_hw_supported): * [​`flash_attention_ragged`](./flash_attention_ragged): * [​`get_mha_decoding_num_partitions`](./get_mha_decoding_num_partitions): * [​`mha`](./mha): * [​`mha_decoding`](./mha_decoding): * [​`mha_decoding_single_batch`](./mha_decoding_single_batch): Flash attention v2 algorithm. * [​`mha_decoding_single_batch_pipelined`](./mha_decoding_single_batch_pipelined): Flash attention v2 algorithm. * [​`mha_gpu_naive`](./mha_gpu_naive): * [​`mha_single_batch`](./mha_single_batch): MHA for token gen where seqlen = 1 and num\_keys >= 1. * [​`mha_single_batch_pipelined`](./mha_single_batch_pipelined): MHA for token gen where seqlen = 1 and num\_keys >= 1. * [​`mha_splitk_reduce`](./mha_splitk_reduce): * [​`q_num_matrix_view_rows`](./q_num_matrix_view_rows): * [​`scale_and_mask_helper`](./scale_and_mask_helper):
--- ## mha (Mha)
`mha[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, valid_length_layout: Layout, config: MHAConfig[dtype], group: Int = 1, use_score_mod: Bool = False, ragged: Bool = False, is_shared_kv: Bool = False, sink: Bool = False, _use_valid_length: Bool = False, _is_cache_length_accurate: Bool = False, _padded_ndbuffer: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], scale: Float32, batch_size: Int, seq_len_arg: Int, num_keys_arg: Int, valid_length: LayoutTensor[DType.uint32, valid_length_layout, MutAnyOrigin], kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]], sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]], mask: mask_t, score_mod: score_mod_t)`
--- ## mha_decoding
`mha_decoding[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, valid_length_layout: Layout, BM: UInt, BN: UInt, BK: UInt, WM: UInt, WN: UInt, depth: UInt, num_heads: UInt, num_threads: UInt, num_pipeline_stages: UInt, group: UInt = 1, use_score_mod: Bool = False, ragged: Bool = False, is_shared_kv: Bool = False, sink: Bool = False, _use_valid_length: Bool = False, _is_cache_length_accurate: Bool = False, decoding_warp_split_k: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], scale: Float32, batch_size: Int, num_partitions: Int, max_cache_valid_length: Int, valid_length: LayoutTensor[DType.uint32, valid_length_layout, MutAnyOrigin], sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]], mask: mask_t, score_mod: score_mod_t)`
--- ## mha_decoding_single_batch
`mha_decoding_single_batch[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, BM: UInt, BN: UInt, BK: UInt, WM: UInt, WN: UInt, depth: UInt, num_heads: UInt, num_threads: UInt, num_pipeline_stages: UInt, group: UInt = 1, use_score_mod: Bool = False, decoding_warp_split_k: Bool = False, sink: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], scale: Float32, num_keys: UInt, num_partitions: UInt, max_cache_valid_length: UInt, mask: mask_t, score_mod: score_mod_t, batch_idx: Int, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]])` Flash attention v2 algorithm.
--- ## mha_decoding_single_batch_pipelined
`mha_decoding_single_batch_pipelined[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, BM: UInt, BN: UInt, BK: UInt, WM: UInt, WN: UInt, depth: UInt, num_heads: UInt, num_threads: UInt, num_pipeline_stages: UInt, group: UInt = 1, use_score_mod: Bool = False, decoding_warp_split_k: Bool = False, sink: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], scale: Float32, num_keys: UInt, num_partitions: UInt, max_cache_valid_length: UInt, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]], mask: mask_t, score_mod: score_mod_t, batch_idx: Int)` Flash attention v2 algorithm.
--- ## mha_gpu_naive
`mha_gpu_naive[output_type: DType, k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, //, ragged: Bool = False, sink: Bool = False, _use_valid_length: Bool = False, _is_cache_length_accurate: Bool = False](q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: k_t, v: v_t, mask_functor: mask_t, output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, batch_size: Int, max_prompt_len: Int, max_cache_size: Int, num_heads: Int, depth: Int, group: Int, ctx: DeviceContext, sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutAnyOrigin]] = None)` `mha_gpu_naive[q_type: DType, k_type: DType, v_type: DType, output_type: DType, mask_type: DType, //, sink: Bool = False](q: LayoutTensor[q_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[k_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[v_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask: LayoutTensor[mask_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, 
masked=masked, alignment=alignment], scale: Float32, batch_size: Int, seq_len: Int, num_keys: Int, num_heads: Int, depth: Int, group: Int, ctx: DeviceContext, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]] = None)` `mha_gpu_naive[q_type: DType, output_type: DType, cache_t: KVCacheT, mask_t: MHAMask, //, ragged: Bool = False, sink: Bool = False](q: LayoutTensor[q_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: cache_t, v: cache_t, mask_functor: mask_t, output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, batch_size: Int, max_prompt_len: Int, max_cache_size: Int, num_heads: Int, depth: Int, group: Int, ctx: DeviceContext, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## mha_single_batch
`mha_single_batch[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, config: MHAConfig[dtype], group: Int = 1, use_score_mod: Bool = False, sink: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], scale: Float32, seq_len: Int, max_seq_len: Int, start_pos: UInt32, num_keys: Int, mask_tensor_col: Int, mask: mask_t, score_mod: score_mod_t, batch_idx: Int, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]])` MHA for token generation, where seqlen = 1 and num\_keys >= 1. The general data layout and steps conform to flash attention, with two exceptions: (1) work is partitioned across B, H, and num\_keys (TODO); the num\_keys split is split-K and will need a separate reduction kernel at the end. (2) The first bmm becomes a gemv and the second bmm becomes a gevm. TODO: use more optimized kernels for them.
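The gemv/gevm degeneration described above can be illustrated in NumPy for a single head with a single new query token (a hypothetical helper for illustration, not the Mojo implementation; cache layout and names are assumptions):

```python
import numpy as np

def decode_step(q_row, k_cache, v_cache, scale):
    """One decoding step for one head.

    q_row:            (D,)  query for the single new token (seqlen = 1).
    k_cache, v_cache: (num_keys, D) cached keys / values.
    """
    # First "bmm" degenerates to a gemv: one score per cached key.
    s = k_cache @ q_row * scale            # (num_keys,)
    # Stable softmax over the scores.
    w = np.exp(s - s.max())
    w /= w.sum()
    # Second "bmm" degenerates to a gevm: weighted sum of value rows.
    return w @ v_cache                     # (D,)
```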
--- ## mha_single_batch_pipelined
`mha_single_batch_pipelined[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, config: MHAConfig[dtype], group: Int = 1, use_score_mod: Bool = False, sink: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], scale: Float32, seq_len: Int, max_seq_len: Int, start_pos: UInt32, num_keys: Int, mask_tensor_col: Int, mask: mask_t, score_mod: score_mod_t, batch_idx: Int, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]])` MHA for token generation, where seqlen = 1 and num\_keys >= 1. The general data layout and steps conform to flash attention, with two exceptions: (1) work is partitioned across B, H, and num\_keys (TODO); the num\_keys split is split-K and will need a separate reduction kernel at the end. (2) The first bmm becomes a gemv and the second bmm becomes a gevm. TODO: use more optimized kernels for them.
--- ## mha_splitk_reduce
`mha_splitk_reduce[output_type: DType, depth: UInt, num_heads: UInt, num_threads: UInt, group: UInt = 1, use_exp2: Bool = False](intermediate_ptr: LegacyUnsafePointer[Scalar[output_type]], output_ptr: LegacyUnsafePointer[Scalar[output_type]], exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[output_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[output_type]()]], batch_size: Int, num_partitions: Int)`
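The split-K combine that such a reduction performs can be sketched in NumPy, assuming each partition stores an unnormalized output accumulator together with its exp-sum and per-partition score maximum (the storage convention and names here are assumptions for illustration, not the actual kernel interface):

```python
import numpy as np

def splitk_reduce(partials, exp_sums, qk_maxes):
    """Combine per-partition attention results into one normalized output.

    partials: (P, D) unnormalized accumulators, sum_j exp(s_j - m_p) * v_j.
    exp_sums: (P,)   per-partition sums of exp(s_j - m_p).
    qk_maxes: (P,)   per-partition score maxima m_p.
    """
    m = qk_maxes.max()
    # Rescale every partition from its local maximum to the global one.
    rescale = np.exp(qk_maxes - m)
    denom = (exp_sums * rescale).sum()
    return (partials * rescale[:, None]).sum(axis=0) / denom
```

Because each partition is rescaled by `exp(m_p - m)`, the combined result matches a single-pass softmax-weighted sum over all keys.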
--- ## q_num_matrix_view_rows
`q_num_matrix_view_rows[dtype: DType, //](q: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## scale_and_mask_helper
`scale_and_mask_helper[p_type: DType, p_layout: Layout, mask_t: MHAMask, score_mod_t: ScoreModTrait, group: Int, num_n_mmas: Int, WN: Int, MMA_N: Int, simd_width: Int, use_score_mod: Bool = False](p_reg_tile: LayoutTensor[p_type, p_layout, origin, address_space=AddressSpace.LOCAL], scale_log2e: Float32, num_keys: UInt, bound: UInt, lane: UInt, warp: UInt, mask: mask_t, score_mod: score_mod_t, kv_tile_start_row: Int, mask_stride: UInt, max_seq_len: Int)`
--- ## mha_cross
## Functions * [​`mha_cross_gpu_naive`](./mha_cross_gpu_naive): Naive cross attention on GPU.
--- ## mha_cross_gpu_naive
`mha_cross_gpu_naive[cache_t: KVCacheT, mask_t: MHAMask, dtype: DType, //, rank: Int](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q_max_seq_len: Int, k: cache_t, v: cache_t, kv_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask_functor: mask_t, scale: Float32, ctx: DeviceContext)` Naive cross attention on GPU. Note that this assumes ragged tensor inputs and uses a mask functor. Computes: (1) Transpose (Q) BSHD -> BHSD; (2) Transpose (K) BSHD -> BHSD; (3) Transpose (V) BSHD -> BHSD; (4) P = Bmm(Q, K), P is also called "score"; (5) P = P \* scale + mask; (6) P = softmax(P); (7) O = Bmm(P, V); (8) Output = Transpose(O). B, S, H, D denote batch size, sequence length, head count, and depth, respectively. Steps (1), (2), and (3) happen while loading the data into shared memory; (8) happens when writing output to global memory. All inputs (query, key, and value) must have BSHD layout. The mask can be BSS or BHSS. This kernel also handles the grouped-attention optimization, in which case the shapes of K and V are BShD where h = H / num\_groups.
--- ## MHAPosition
`@register_passable(trivial)` `struct MHAPosition[BM: Int, BN: Int, depth: Int, padded_depth: Int, q_num_heads: Int, group: Int, decoding: Bool]` Position of the MHA-kernel. When `decoding=False`, `q_head_stride == q_num_heads`. When `decoding=True`, `q_head_stride == 1`. ## Fields * ​q\_row (`UInt32`): * ​q\_col (`UInt32`): * ​q\_out\_offset (`Int`): * ​num\_keys (`UInt32`): * ​start\_pos (`UInt32`): * ​seq\_len (`UInt32`): * ​head\_idx (`UInt32`): * ​prompt\_offset (`UInt32`): * ​prompt\_idx (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_q_heads_per_thread` `comptime num_q_heads_per_thread = min(2, ceildiv(group, 8)) if decoding else 1` ### `q_output_gmem_layout` `comptime q_output_gmem_layout = Layout(IntTuple(BM, depth), IntTuple(MHAPosition[BM, BN, depth, padded_depth, q_num_heads, group, decoding].q_stride, 1))` ### `q_stride` `comptime q_stride = depth if decoding else (depth * q_num_heads)` ### `split_gmem_layout` `comptime split_gmem_layout = Layout(IntTuple((BM // 2), depth), IntTuple(MHAPosition[BM, BN, depth, padded_depth, q_num_heads, group, decoding].q_stride, 1))` ## Methods ### `__init__` `__init__(q_row: UInt32, q_col: UInt32, q_out_offset: Int, num_keys: UInt32, start_pos: UInt32, seq_info: SeqInfo) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `q_head_idx` 
`q_head_idx(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `kv_head_idx` `kv_head_idx(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `write_to` `write_to(self, mut writer: T)` ### `q_tile_num_rows` `q_tile_num_rows(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `q_out_gmem_tensor` `q_out_gmem_tensor[dtype: DType](self, ptr: LegacyUnsafePointer[Scalar[dtype]]) -> LayoutTensor[dtype, MHAPosition[BM, BN, depth, padded_depth, q_num_heads, group, decoding].q_output_gmem_layout, MutAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `mask_status` `mask_status[MaskType: MHAMask](self, mask: MaskType, kv_tile_start_row: UInt32) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `get_score_row` `get_score_row(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `exp_sum_qk_max_ptr` `exp_sum_qk_max_ptr[partition_t: MHAPartitionScheme](self, partition: partition_t, batch_size: UInt32) -> Tuple[LegacyUnsafePointer[Scalar[partition_t.accum_dtype]], LegacyUnsafePointer[Scalar[partition_t.accum_dtype]]]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `get_start_and_end_for_partitions` `get_start_and_end_for_partitions[partition_t: MHAPartitionScheme, //](self, partition: partition_t) -> Tuple[UInt32, UInt32]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `get_q_gmem_row` `static get_q_gmem_row[MaxSeqLenType: OptionallyStaticInt, //, ragged: Bool](seq_info: SeqInfo, max_seq_len: MaxSeqLenType) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) `static get_q_gmem_row[ragged: Bool](seq_info: SeqInfo, max_seq_len: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## NonNullPointer
`@register_passable(trivial)` `struct NonNullPointer[dtype_: DType]` ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[NonNullPointer[dtype_].dtype]]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OptionalPointer`](/mojo/kernels/nn/mha_fa3_utils/OptionalPointer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `dtype` `comptime dtype = dtype_` ### `is_null` `comptime is_null = False` ## Methods ### `__init__` `__init__(ptr: LegacyUnsafePointer[Scalar[NonNullPointer[dtype_].dtype]]) -> Self` `__init__(ptr: DeviceBuffer[NonNullPointer[dtype_].dtype]) -> Self` ### `value` `value(self) -> LegacyUnsafePointer[Scalar[NonNullPointer[dtype_].dtype]]` **Returns:** `LegacyUnsafePointer`
--- ## NullPointer
`@register_passable(trivial)` `struct NullPointer[dtype_: DType]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OptionalPointer`](/mojo/kernels/nn/mha_fa3_utils/OptionalPointer), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `dtype` `comptime dtype = dtype_` ### `is_null` `comptime is_null = True` ## Methods ### `__init__` `__init__() -> Self` ### `value` `value(self) -> LegacyUnsafePointer[Scalar[NullPointer[dtype_].dtype]]` **Returns:** `LegacyUnsafePointer`
--- ## OptionalPointer
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `dtype` `comptime dtype` ### `is_null` `comptime is_null` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. 
**Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `value` `value(self: _Self) -> LegacyUnsafePointer[Scalar[_Self.dtype]]` **Returns:** `LegacyUnsafePointer` ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## Pack
`@register_passable(trivial)` `struct Pack[MaskType: MHAMask, ScoreModType: ScoreModTrait, SchedulerType: MHATileScheduler, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme]` ## Fields * ​mask (`MaskType`): * ​score\_mod (`ScoreModType`): * ​scheduler (`SchedulerType`): * ​valid\_length (`ValidLengthType`): * ​sink\_weights (`SinkType`): * ​kv\_input\_row\_offsets (`KVRowOffsetsType`): * ​max\_seq\_len (`MaxSeqLenType`): * ​partition (`PartitionType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = MaskType.__copyinit__is_trivial and ScoreModType.__copyinit__is_trivial and SchedulerType.__copyinit__is_trivial and ValidLengthType.__copyinit__is_trivial and SinkType.__copyinit__is_trivial and KVRowOffsetsType.__copyinit__is_trivial and MaxSeqLenType.__copyinit__is_trivial and PartitionType.__copyinit__is_trivial` True only when every component type has a trivial copy constructor. ### `__del__is_trivial` `comptime __del__is_trivial = MaskType.__del__is_trivial and ScoreModType.__del__is_trivial and SchedulerType.__del__is_trivial and ValidLengthType.__del__is_trivial and SinkType.__del__is_trivial and KVRowOffsetsType.__del__is_trivial and MaxSeqLenType.__del__is_trivial and PartitionType.__del__is_trivial` True only when every component type has a trivial destructor. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = MaskType.__moveinit__is_trivial and ScoreModType.__moveinit__is_trivial and SchedulerType.__moveinit__is_trivial and ValidLengthType.__moveinit__is_trivial and SinkType.__moveinit__is_trivial and KVRowOffsetsType.__moveinit__is_trivial and MaxSeqLenType.__moveinit__is_trivial and PartitionType.__moveinit__is_trivial` True only when every component type has a trivial move constructor. ### `device_type` `comptime device_type = Pack[MaskType, ScoreModType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType]` ## Methods ### `__init__` `__init__(mask: MaskType, score_mod: ScoreModType, scheduler: SchedulerType, valid_length: ValidLengthType, sink_weights: SinkType, kv_input_row_offsets: KVRowOffsetsType, max_seq_len: MaxSeqLenType, partition: PartitionType) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
--- ## PositionSummary
`@register_passable(trivial)` `struct PositionSummary` ## Fields * ​num\_keys (`UInt32`): * ​score\_row (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(num_keys: UInt32, score_row: UInt32) -> Self` ### `get_start_pos` `static get_start_pos[KVLUTType: MHAOperand, //, ragged: Bool, _is_cache_length_accurate: Bool](kv_lut: KVLUTType, seq_info: SeqInfo, num_keys_arg: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_num_keys` `static get_num_keys[MaxSeqLenType: OptionallyStaticInt, KVInputRowOffsetsType: OptionalPointer, //, ragged: Bool, _is_cache_length_accurate: Bool](kv_input_row_offsets: KVInputRowOffsetsType, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, num_keys_arg: UInt32, start_pos: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_score_row` `static get_score_row[*, ragged: Bool, _is_cache_length_accurate: Bool, decoding: Bool](seq_info: SeqInfo, num_keys: UInt32, start_pos: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create` `static create[KVLUTType: MHAOperand, KVRowOffsetsType: OptionalPointer, MaxSeqLenType: OptionallyStaticInt, //, ragged: Bool, _is_cache_length_accurate: Bool](kv_lut: KVLUTType, seq_info: SeqInfo, num_keys_arg: UInt32, kv_input_row_offsets: KVRowOffsetsType, max_seq_len: MaxSeqLenType) -> Self`
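Judging from the signatures alone, `create` appears to thread values through the three static helpers in order: the start position feeds the key count, and both feed the score row. A hypothetical Python sketch of that dataflow (the helper bodies here are invented stand-ins for illustration, not the real Mojo implementations, which also consult `kv_lut`, `seq_info`, and `max_seq_len`):

```python
from dataclasses import dataclass


@dataclass
class PositionSummary:
    num_keys: int
    score_row: int


# Stand-in helpers: the arithmetic below is illustrative only.
def get_start_pos(num_keys_arg: int, cache_len: int) -> int:
    # e.g. where the new tokens begin within the cached sequence
    return cache_len


def get_num_keys(num_keys_arg: int, start_pos: int) -> int:
    # cached keys plus the new ones
    return num_keys_arg + start_pos


def get_score_row(num_keys: int, start_pos: int) -> int:
    # row offset into the score matrix for this position
    return num_keys - start_pos


def create(num_keys_arg: int, cache_len: int) -> PositionSummary:
    # The apparent call order of the static helpers in `create`:
    start_pos = get_start_pos(num_keys_arg, cache_len)
    num_keys = get_num_keys(num_keys_arg, start_pos)
    score_row = get_score_row(num_keys, start_pos)
    return PositionSummary(num_keys, score_row)
```

The point of the sketch is only the ordering: `start_pos` must be computed first because both `get_num_keys` and `get_score_row` consume it.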
--- ## get_q_head_idx
`get_q_head_idx[BM: Int, BN: Int, depth: Int, padded_depth: Int, num_heads: Int, group: Int, decoding: Bool, //](position: MHAPosition[BM, BN, depth, padded_depth, num_heads, group, decoding], lane: UInt32) -> StaticTuple[UInt32, MHAPosition[BM, BN, depth, padded_depth, num_heads, group, decoding].num_q_heads_per_thread]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
--- ## get_seq_info
`get_seq_info[MaxSeqLenType: OptionallyStaticInt, ValidLengthType: OptionalPointer, PartitionType: MHAPartitionScheme, //, BM: Int, num_heads: Int](batch_size: UInt32, max_seq_len: MaxSeqLenType, valid_length: ValidLengthType, partition: PartitionType) -> SeqInfo` **Returns:** `SeqInfo`
--- ## mha_fa3_utils
## `comptime` values ### `KVTMATile` `comptime KVTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BN: Int, depth: Int, BK: Int = depth] = TMATensorTile[dtype, _split_last_layout[dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)]` #### Parameters * ​dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​swizzle\_mode (`TensorMapSwizzle`): * ​BN ([`Int`](/stdlib/builtin/int/Int)): * ​depth ([`Int`](/stdlib/builtin/int/Int)): * ​BK ([`Int`](/stdlib/builtin/int/Int)): ### `QTMATile` `comptime QTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int, group: Int, decoding: Bool] = TMATensorTile[dtype, _split_last_layout[dtype](q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding](), swizzle_mode, True), _ragged_desc_layout[dtype](q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding](), swizzle_mode)]` #### Parameters * ​dtype ([`DType`](/stdlib/builtin/dtype/DType)): * ​swizzle\_mode (`TensorMapSwizzle`): * ​BM ([`Int`](/stdlib/builtin/int/Int)): * ​depth ([`Int`](/stdlib/builtin/int/Int)): * ​group ([`Int`](/stdlib/builtin/int/Int)): * ​decoding ([`Bool`](/stdlib/builtin/bool/Bool)): ## Structs * [​`MHAPosition`](./MHAPosition): Position of the MHA-kernel. When `decoding=False`, `q_head_stride == q_num_heads`. When `decoding=True`, `q_head_stride == 1`. 
* [​`NonNullPointer`](./NonNullPointer): * [​`NullPointer`](./NullPointer): * [​`Pack`](./Pack): * [​`PositionSummary`](./PositionSummary): ## Traits * [​`OptionalPointer`](./OptionalPointer): ## Functions * [​`get_q_head_idx`](./get_q_head_idx): * [​`get_seq_info`](./get_seq_info): * [​`kv_coord`](./kv_coord): * [​`output_reg_to_smem`](./output_reg_to_smem): * [​`output_reg_to_smem_st_matrix`](./output_reg_to_smem_st_matrix): * [​`produce`](./produce): * [​`q_coord`](./q_coord): Returns the coordinates for a TMA load on the `Q` matrix. This load can be 3D, 4D, or 5D. * [​`q_gmem_shape`](./q_gmem_shape): * [​`q_smem_shape`](./q_smem_shape): * [​`q_tma`](./q_tma):
--- ## kv_coord
`kv_coord[*, depth: Int, swizzle_granularity: Int](row: UInt32, head_idx: UInt32) -> StaticTuple[UInt32, 4 if _should_split_last_dim(depth, swizzle_granularity) else 3]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
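The rank of the returned tuple comes straight from the signature: 4 coordinates when the last dimension is split for swizzling, otherwise 3. A minimal Python sketch of that selection (the split predicate here is an assumed stand-in, since `_should_split_last_dim` is internal; the guess is that the last dimension splits when `depth` exceeds the swizzle granularity):

```python
def kv_coord_rank(depth: int, swizzle_granularity: int) -> int:
    # Stand-in for _should_split_last_dim: assume a split is needed
    # when the depth dimension is wider than the swizzle granularity.
    should_split = depth > swizzle_granularity
    # Mirrors: StaticTuple[UInt32, 4 if _should_split_last_dim(...) else 3]
    return 4 if should_split else 3
```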
--- ## output_reg_to_smem
`output_reg_to_smem[output_type: DType, accum_type: DType, num_m_mmas: Int, o_frag_size: Int, //, BM: Int, BN: Int, padded_depth: Int, swizzle: Swizzle, num_consumer: Int](tid: UInt32, local_warp_group_idx: UInt32, warp_y: UInt32, q_smem: LegacyUnsafePointer[Scalar[output_type], address_space=AddressSpace.SHARED], output_reg_tile: LayoutTensor[accum_type, Layout.row_major(num_m_mmas, o_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL]) -> LayoutTensor[output_type, Layout.row_major(BM, padded_depth), MutAnyOrigin, address_space=AddressSpace.SHARED]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## output_reg_to_smem_st_matrix
`output_reg_to_smem_st_matrix[output_type: DType, accum_type: DType, num_m_mmas: Int, o_frag_size: Int, //, BM: Int, padded_depth: Int, swizzle: Swizzle, num_consumer: Int](warp_group_thread_idx: UInt32, local_warp_group_idx: UInt32, output_reg_tile: LayoutTensor[accum_type, Layout.row_major(num_m_mmas, o_frag_size), MutAnyOrigin, address_space=AddressSpace.LOCAL], accum_smem_tile: LayoutTensor[output_type, Layout.row_major(BM, padded_depth), MutAnyOrigin, address_space=AddressSpace.SHARED])`
--- ## produce
`produce[qkv_type: DType, BM: Int, BN: Int, q_smem_layout: Layout, q_desc_layout: Layout, k_smem_layout: Layout, k_desc_layout: Layout, v_smem_layout: Layout, v_desc_layout: Layout, depth: Int, padded_depth: Int, num_heads: Int, group: Int, PartitionType: MHAPartitionScheme, MaxSeqLenType: OptionallyStaticInt, SchedulerType: MHATileScheduler, KVLUTType: MHAOperand, MaskType: MHAMask, KVInputRowOffsetsType: OptionalPointer, ValidLengthType: OptionalPointer, //, swizzle_mode: TensorMapSwizzle, *, pipeline_stages: Int, ragged: Bool, _is_cache_length_accurate: Bool](q_tma_op: TMATensorTile[qkv_type, q_smem_layout, q_desc_layout], k_tma_op: TMATensorTile[qkv_type, k_smem_layout, k_desc_layout], v_tma_op: TMATensorTile[qkv_type, v_smem_layout, v_desc_layout], q_smem: LegacyUnsafePointer[Scalar[qkv_type], address_space=AddressSpace.SHARED], kv_smem: LegacyUnsafePointer[Scalar[qkv_type], address_space=AddressSpace.SHARED], produced_mbar_kv: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], consumed_mbar_kv: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], produced_mbar_q: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], consumed_mbar_q: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], kv_lut: KVLUTType, initial_position: MHAPosition[BM, BN, depth, padded_depth, num_heads, group, _is_decoding[MaxSeqLenType]()], partition: PartitionType, scheduler: SchedulerType, mask: MaskType, tile_summary: MHATileSummary[ValidLengthType], tile_state_arg: MHATileState, max_seq_len: MaxSeqLenType, num_keys_arg: UInt32, kv_input_row_offsets: KVInputRowOffsetsType)`
--- ## q_coord
`q_coord[*, depth: Int, swizzle_granularity: Int, decoding: Bool](row: UInt32, head_idx: UInt32) -> StaticTuple[UInt32, (4 if decoding else 3 + Int.__init__[Bool](_should_split_last_dim(depth, swizzle_granularity)))]` Returns the coordinates for a TMA load on the `Q` matrix. This load can be 3D, 4D, or 5D. **Args:** * ​row: The row to load from. * ​head\_idx: `q_head_idx` if prefilling, `kv_head_idx` if decoding. **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
--- ## q_gmem_shape
`q_gmem_shape[dtype: DType, swizzle_mode: TensorMapSwizzle, *, group: Int, q_num_heads: Int, depth: Int, decoding: Bool]() -> IndexList[3 if (not decoding._mlir_value) else 5 if _should_split_last_dim[dtype](depth, swizzle_mode) else 4]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## q_smem_shape
`q_smem_shape[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, group: Int, depth: Int, decoding: Bool]() -> IndexList[3 if (not decoding._mlir_value) else 5 if _should_split_last_dim[dtype](depth, swizzle_mode) else 4]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## q_tma
`q_tma[dtype: DType, //, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int, q_num_heads: Int, group: Int, decoding: Bool](ctx: DeviceContext, ptr: LegacyUnsafePointer[Scalar[dtype]], rows: Int) -> TMATensorTile[dtype, _split_last_layout[dtype](q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding](), swizzle_mode, True), _ragged_desc_layout[dtype](q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding](), swizzle_mode)]` **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)
--- ## AndMask
`@register_passable(trivial)` `struct AndMask[T: MHAMask, S: MHAMask, //, lhs: T, rhs: S]` Mask that's the AND of two masks. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAMask`](/mojo/kernels/nn/mha_mask/MHAMask), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask = T.apply_log2e_after_mask if T.apply_log2e_after_mask else S.apply_log2e_after_mask` ### `check_mask_during_decoding` `comptime check_mask_during_decoding = S.check_mask_during_decoding if T.check_mask_during_decoding else T.check_mask_during_decoding` ### `device_type` `comptime device_type = AndMask[lhs, rhs]` ### `mask_out_of_bound` `comptime mask_out_of_bound = T.mask_out_of_bound if T.mask_out_of_bound else S.mask_out_of_bound` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds = S.mask_safe_out_of_bounds if T.mask_safe_out_of_bounds else T.mask_safe_out_of_bounds` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, 
width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `start_column` `start_column[BM: Int, BN: Int, page_size: Int](self, row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, AndMask.count_nonfull_sets[T, S, lhs, rhs](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, AndMask.count_nonfull_sets[T, S, lhs, rhs](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
--- ## CausalMask
`@register_passable(trivial)` `struct CausalMask` MHA causal mask ensures a token is only affected by previous tokens. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAMask`](/mojo/kernels/nn/mha_mask/MHAMask), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask = False` ### `check_mask_during_decoding` `comptime check_mask_during_decoding = False` ### `device_type` `comptime device_type = CausalMask` ### `mask_out_of_bound` `comptime mask_out_of_bound = is_nvidia_gpu()` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds = True` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### 
`start_column` `start_column[BM: Int, BN: Int, page_size: Int](self, row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, CausalMask.count_nonfull_sets(BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, CausalMask.count_nonfull_sets(BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
--- ## ChunkedCausalMask
`ChunkedCausalMask[local_window_size: Int]() -> OrMask[CausalMask(), ChunkedMask[local_window_size]()]` Mask implementing Chunked Causal attention for Llama4 models. This groups the mask into chunks of size `local_window_size` and performs causal attention within each local chunk. Considering the following case:

* Q\_len = 7
* K\_len = 10
* start\_pos = 3
* local\_window\_size = 4

The mask will be applied as follows:

```
K > 0 1 2 3 4 5 6 7 8 9
Q v x--------------------x
0 | 1 1 1 1 0 0 0 0 0 0
1 | 0 0 0 0 1 0 0 0 0 0
2 | 0 0 0 0 1 1 0 0 0 0
3 | 0 0 0 0 1 1 1 0 0 0
4 | 0 0 0 0 1 1 1 1 0 0
5 | 0 0 0 0 0 0 0 0 1 0
6 | 0 0 0 0 0 0 0 0 1 1
```

**Returns:** `OrMask`
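Since `ChunkedCausalMask` is the OR of `CausalMask` and `ChunkedMask`, a position is masked out when either component masks it; equivalently, a key is visible only when it is both causally reachable and in the same chunk as the query. The example table above can be reproduced with a short Python sketch (illustrative only; this is not the Mojo API):

```python
def chunked_causal_allowed(q_len, k_len, start_pos, window):
    """Visibility matrix for chunked-causal attention.

    Key column k is visible to query row i when it is causally reachable
    (k <= absolute query position) AND lies in the same `window`-sized
    chunk as the query.
    """
    rows = []
    for i in range(q_len):
        q_abs = start_pos + i  # absolute position of this query token
        rows.append([
            1 if (k <= q_abs and k // window == q_abs // window) else 0
            for k in range(k_len)
        ])
    return rows

# The docstring example: Q_len = 7, K_len = 10, start_pos = 3, window = 4.
mask = chunked_causal_allowed(q_len=7, k_len=10, start_pos=3, window=4)
```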
--- ## ChunkedMask
`@register_passable(trivial)` `struct ChunkedMask[local_window_size: Int]` Mask implementing Chunked attention. This groups the mask into chunks of size `local_window_size`. Considering the following case:

* Q\_len = 7
* K\_len = 10
* local\_window\_size = 4

The mask will be applied as follows:

```
K > 0 1 2 3 4 5 6 7 8 9
Q v x--------------------x
0 | 1 1 1 1 0 0 0 0 0 0
1 | 0 0 0 0 1 1 1 1 0 0
2 | 0 0 0 0 1 1 1 1 0 0
3 | 0 0 0 0 1 1 1 1 0 0
4 | 0 0 0 0 1 1 1 1 0 0
5 | 0 0 0 0 0 0 0 0 1 1
6 | 0 0 0 0 0 0 0 0 1 1
```

## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAMask`](/mojo/kernels/nn/mha_mask/MHAMask), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask = False` ### `check_mask_during_decoding` `comptime check_mask_during_decoding = True` ### `device_type` `comptime device_type = ChunkedMask[local_window_size]` ### `mask_out_of_bound` `comptime mask_out_of_bound = True` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds = True` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = 
DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `start_column` `start_column[BM: Int, BN: Int, page_size: Int](self, row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, ChunkedMask.count_nonfull_sets[local_window_size](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, ChunkedMask.count_nonfull_sets[local_window_size](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
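The chunk rule in the example above can be sketched in Python (illustrative only; not the Mojo API). The example's matrix is consistent with query row `i` sitting at absolute position `i + (K_len - Q_len)`, an assumption this sketch makes explicit:

```python
def chunked_allowed(q_len, k_len, window):
    """Visibility matrix for chunked attention: key column k is visible to
    query row i exactly when both fall in the same `window`-sized chunk.
    Query row i is assumed to sit at absolute position i + (k_len - q_len).
    """
    start_pos = k_len - q_len
    return [
        [1 if k // window == (start_pos + i) // window else 0
         for k in range(k_len)]
        for i in range(q_len)
    ]

# The docstring example: Q_len = 7, K_len = 10, window = 4.
mask = chunked_allowed(q_len=7, k_len=10, window=4)
```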
--- ## MHAMask
The MHAMask trait describes masks for MHA kernels, such as the causal mask. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask` Does the mask require `log2e` to be applied after the mask, or can it be fused with the scaling? 
### `check_mask_during_decoding` `comptime check_mask_during_decoding` Should we check the mask during decoding, or should we assume that it does not return `FULL_MASK`? ### `device_type` `comptime device_type` Indicates the type used on accelerator devices. ### `mask_out_of_bound` `comptime mask_out_of_bound` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds` Is the mask safe to read out of bounds? ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]` Returns the mask vector at the given coordinates. **Args:** * ​coord: The `(seq_id, head, q_idx, k_idx)` coordinate in the score matrix. * ​score\_vec: The score vector at `coord`. The functor may capture a mask tensor and add it to the score (e.g., for Replit). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self: _Self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` Given a tile's index range, return its masking status. **Returns:** `TileMaskStatus` ### `start_column` `start_column[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32) -> UInt32` Returns the first column for which this mask does not return `TileMaskStatus.FULL_MASK`. 
This may not be a multiple of `BN`, in which case iterating using `start_column` and `masked_set_ends` will not necessarily produce the same set or number of iterations as iterating from `0` and checking `status` to skip. The return value of `total_iters` should be less than or equal to the number of non-skipped iterations. The practical consequence is that all warp group specializations within a kernel that loop over columns need to be in agreement: either they all loop over all columns and check `status` to skip, or they all loop using `masked_set_ends`. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32` The total number of column iterations for which this mask returns either `TileMaskStatus.NO_MASK` or `TileMaskStatus.PARTIAL_MASK`. This is to be used by warp specializations that do not need to use `kv_row`. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` The number of blocks that are all partial-masks or not masked. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, _Self.count_nonfull_sets(::Int,::Int)(BM, BN)]` For each set of iterations in `nonfull_sets`, indicate the end idx belonging to that set (i.e., the last idx would be `end - 1`). Note that the final `masked_set_ends` entry may not necessarily equal `total_iters` if we have `UNKNOWN_MASK`s. In case of `UNKNOWN_MASK`s, `masked_set_ends` with tile-skipping must be used to get the correct `kv_row` values at each iteration. 
**Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self: _Self, row: UInt32, num_cols: UInt32) -> UInt32` Equivalent to `masked_set_ends[BM,BN,page_size](row, num_cols)[-1]`. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, _Self.count_nonfull_sets(::Int,::Int)(BM, BN)]` For each set of iterations that are either partially masked or not masked, this indicates the mask status. `UNKNOWN_MASK` here is an indicator meaning that we should check the status at runtime. It is semantically equivalent to `partial`, but with the optimization hint that it's worth checking on each iteration at runtime for `FULL_MASK` (in which case we can skip the tile) or `NO_MASK` (in which case we can unswitch and avoid masking in an inner loop). **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. 
TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
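To make the `status`-based tile-skipping protocol concrete, here is a hedged Python sketch of the three-way tile classification for a causal mask (the function names and string statuses are illustrative; the Mojo trait returns a `TileMaskStatus`):

```python
def causal_tile_status(tile_row, tile_col, BM, BN):
    """Classify a BM x BN score tile under a causal mask (k visible iff k <= q).

    FULL_MASK    -> every element is masked; the kernel may skip the tile.
    NO_MASK      -> nothing is masked; the inner loop can omit masking.
    PARTIAL_MASK -> the tile straddles the causal diagonal.
    """
    if tile_col > tile_row + BM - 1:    # smallest k already exceeds largest q
        return "FULL_MASK"
    if tile_col + BN - 1 <= tile_row:   # largest k still within smallest q
        return "NO_MASK"
    return "PARTIAL_MASK"


def nonskipped_tiles(seq_len, BM, BN, row_block):
    """Column-tile indices a consumer would actually visit for one row block,
    i.e. those whose status is not FULL_MASK."""
    r0 = row_block * BM
    return [c // BN for c in range(0, seq_len, BN)
            if causal_tile_status(r0, c, BM, BN) != "FULL_MASK"]
```

Under this sketch, `len(nonskipped_tiles(...))` plays the role of `total_iters`, and all warp-group specializations must agree on the same skipping scheme, as the `start_column` documentation notes.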
--- ## MaskName
`struct MaskName` Identifies an attention mask variant by name. ## Fields * ​name (`String`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `CAUSAL` `comptime CAUSAL = MaskName("causal")` ### `CHUNKED` `comptime CHUNKED = MaskName("chunked")` ### `CHUNKED_CAUSAL` `comptime CHUNKED_CAUSAL = MaskName("chunked_causal")` ### `MATERIALIZED` `comptime MATERIALIZED = MaskName("materialized")` ### `NULL` `comptime NULL = MaskName("null")` ### `SLIDING_WINDOW_CAUSAL` `comptime SLIDING_WINDOW_CAUSAL = MaskName("sliding_window_causal")` ## Methods ### `__init__` `__init__(out self, name: String)` ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) `__eq__(self, rhs: String) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
--- ## MaterializedMask
`@register_passable(trivial)` `struct MaterializedMask[dtype_: DType, layout_: Layout]` Mask that's backed by a materialized tensor. ## Fields * ​mask\_tensor (`MaterializedMask[dtype_, layout_].MaskType`): * ​start\_pos (`OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]]`): * ​is\_multiple\_of\_2 (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAMask`](/mojo/kernels/nn/mha_mask/MHAMask), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask = True` ### `check_mask_during_decoding` `comptime check_mask_during_decoding = True` ### `device_type` `comptime device_type = MaterializedMask[dtype_, layout_]` ### `mask_out_of_bound` `comptime mask_out_of_bound = True` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds = False` ### `MaskType` `comptime MaskType = LayoutTensor[dtype_, layout_, MutAnyOrigin]` ## Methods ### `__init__` `__init__(mask_tensor: LayoutTensor[dtype_, layout_, MutAnyOrigin], start_pos: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** 
[`String`](/mojo/stdlib/collections/string/string/String) ### `get_start_pos` `get_start_pos(self, batch_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `start_column` `start_column[BM: Int, BN: Int, page_size: Int](self, row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, MaterializedMask.count_nonfull_sets[dtype_, layout_](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, MaterializedMask.count_nonfull_sets[dtype_, layout_](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
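`MaterializedMask` is backed by a tensor rather than a closed-form rule. The `MHAMask.mask` docstring notes that a functor may capture a mask tensor and add it to the score; here is a minimal Python sketch of that additive-mask idea (illustrative only; the actual tensor layout and `mask` signature are Mojo-specific):

```python
# Additive mask convention: 0.0 where attention is allowed, -inf where it is
# masked out. Adding the mask to raw scores drives masked positions to -inf,
# so they vanish after softmax.
NEG_INF = float("-inf")
n = 4
scores = [[0.0] * n for _ in range(n)]
# A causal additive mask materialized as a plain tensor: -inf above the diagonal.
mask_tensor = [[0.0 if k <= q else NEG_INF for k in range(n)] for q in range(n)]
masked = [[s + m for s, m in zip(srow, mrow)]
          for srow, mrow in zip(scores, mask_tensor)]
```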
--- ## NullMask
`@register_passable(trivial)` `struct NullMask` Mask that's effectively a noop. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAMask`](/mojo/kernels/nn/mha_mask/MHAMask), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask = False` ### `check_mask_during_decoding` `comptime check_mask_during_decoding = False` ### `device_type` `comptime device_type = NullMask` ### `mask_out_of_bound` `comptime mask_out_of_bound = True` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds = True` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `start_column` `start_column[BM: Int, BN: Int, page_size: 
Int](self, row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` The total number of column iterations for which this mask returns either `TileMaskStatus.NO_MASK` or `TileMaskStatus.PARTIAL_MASK`. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, NullMask.count_nonfull_sets(BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, NullMask.count_nonfull_sets(BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
--- ## OrMask
`@register_passable(trivial)` `struct OrMask[T: MHAMask, S: MHAMask, //, lhs: T, rhs: S]` Mask that's the OR of two masks. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAMask`](/mojo/kernels/nn/mha_mask/MHAMask), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask = T.apply_log2e_after_mask if T.apply_log2e_after_mask else S.apply_log2e_after_mask` ### `check_mask_during_decoding` `comptime check_mask_during_decoding = T.check_mask_during_decoding if T.check_mask_during_decoding else S.check_mask_during_decoding` ### `device_type` `comptime device_type = OrMask[lhs, rhs]` ### `mask_out_of_bound` `comptime mask_out_of_bound = S.mask_out_of_bound if T.mask_out_of_bound else T.mask_out_of_bound` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds = S.mask_safe_out_of_bounds if T.mask_safe_out_of_bounds else T.mask_safe_out_of_bounds` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) 
-> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `start_column` `start_column[BM: Int, BN: Int, page_size: Int](self, row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, OrMask.count_nonfull_sets[T, S, lhs, rhs](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, OrMask.count_nonfull_sets[T, S, lhs, rhs](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
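As a rough illustration in plain Python (not the Mojo API, and only one plausible reading of "OR of two masks", namely that a score survives if either sub-mask keeps it), combining a causal predicate with a sliding-window predicate could look like:

```python
MASK_VALUE = -10000.0  # mirrors the MASK_VALUE comptime value in mha_mask

def causal_keep(q, k):
    # Causal: a query attends only to keys at or before it.
    return k <= q

def window_keep(q, k, window_size=3):
    # Sliding window: keys within the last `window_size` positions.
    return q - window_size < k <= q

def or_mask(q, k, score):
    # OR-combination under the "keep if either keeps" reading: the
    # score is preserved when either predicate holds, otherwise it is
    # replaced with the large negative mask value.
    if causal_keep(q, k) or window_keep(q, k):
        return score
    return MASK_VALUE

print(or_mask(5, 2, 0.7))  # kept by the causal mask
print(or_mask(2, 5, 0.7))  # kept by neither, so masked out
```

The comptime members (`apply_log2e_after_mask`, `mask_out_of_bound`, ...) are likewise derived from the corresponding members of the two composed mask types.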
--- ## SlidingWindowCausalMask
`@register_passable(trivial)` `struct SlidingWindowCausalMask[window_size: Int]` Mask implementing Sliding Window attention. Consider the following case:

* Q\_len = 7
* K\_len = 7
* window\_size = 3

The mask will be applied as follows:

```
K >    0 1 2 3 4 5 6
Q v   x------------x
0 |    1 0 0 0 0 0 0
1 |    1 1 0 0 0 0 0
2 |    1 1 1 0 0 0 0
3 |    0 1 1 1 0 0 0
4 |    0 0 1 1 1 0 0
5 |    0 0 0 1 1 1 0
6 |    0 0 0 0 1 1 1
```

## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAMask`](/mojo/kernels/nn/mha_mask/MHAMask), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `apply_log2e_after_mask` `comptime apply_log2e_after_mask = False` ### `check_mask_during_decoding` `comptime check_mask_during_decoding = True` ### `device_type` `comptime device_type = SlidingWindowCausalMask[window_size]` ### `mask_out_of_bound` `comptime mask_out_of_bound = True` ### `mask_safe_out_of_bounds` `comptime mask_safe_out_of_bounds = True` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `name` `static name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `mask` `mask[dtype: DType, width: Int, //, *, element_type: DType = DType.uint32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width]) -> 
SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `status` `status[*, element_type: DType = DType.uint32](self, tile_offset: IndexList[2, element_type=element_type], tile_size: IndexList[2, element_type=element_type]) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `start_column` `start_column[BM: Int, BN: Int, page_size: Int](self, row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `total_iters` `total_iters[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `count_nonfull_sets` `static count_nonfull_sets(BM: Int, BN: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `masked_set_ends` `masked_set_ends[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> StaticTuple[UInt32, SlidingWindowCausalMask.count_nonfull_sets[window_size](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple) ### `last_masked_set_end` `last_masked_set_end[BM: Int, BN: Int, page_size: Int](self, row: UInt32, num_cols: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `nonfull_sets` `static nonfull_sets[BM: Int, BN: Int]() -> StaticTuple[TileMaskStatus, SlidingWindowCausalMask.count_nonfull_sets[window_size](BM, BN)]` **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)
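The 7×7 example above can be reproduced with a small Python sketch (illustrative only, not the Mojo API; the struct applies the same predicate elementwise on SIMD vectors):

```python
def sliding_window_causal(q_len, k_len, window_size):
    # 1 = attend, 0 = masked. Query q attends to keys in the
    # half-open window (q - window_size, q]: itself plus the
    # previous window_size - 1 tokens.
    return [[1 if q - window_size < k <= q else 0 for k in range(k_len)]
            for q in range(q_len)]

mask = sliding_window_causal(7, 7, 3)
for row in mask:
    print(*row)
```

Row 0 keeps only key 0, and each later row shifts the window of three ones right by one, matching the diagram.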
--- ## TileMaskStatus
`@register_passable(trivial)` `struct TileMaskStatus` A tile's masking status. ## Fields * ​status (`UInt8`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Identifiable`](/mojo/stdlib/builtin/identifiable/Identifiable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `FULL_MASK` `comptime FULL_MASK = TileMaskStatus(3)` ### `NO_MASK` `comptime NO_MASK = TileMaskStatus(0)` ### `PARTIAL_MASK` `comptime PARTIAL_MASK = TileMaskStatus(1)` ### `UNKNOWN_MASK` `comptime UNKNOWN_MASK = TileMaskStatus(4)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__is__` `__is__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__and__` `__and__(self, rhs: Self) -> Self` ### `__or__` `__or__(self, rhs: Self) -> Self` ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
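The flag encodings above (NO_MASK = 0, PARTIAL_MASK = 1, FULL_MASK = 3, UNKNOWN_MASK = 4) suggest that `__and__`/`__or__` act bitwise on the `status` byte. Under that assumption (a sketch, not the verified implementation), statuses combine as follows:

```python
# Hypothetical Python mirror of the TileMaskStatus encodings.
NO_MASK      = 0  # 0b000: nothing in the tile is masked
PARTIAL_MASK = 1  # 0b001: some elements are masked
FULL_MASK    = 3  # 0b011: the whole tile is masked
UNKNOWN_MASK = 4  # 0b100: cannot be determined statically

# Assuming plain bitwise ops (which FULL_MASK containing the
# PARTIAL_MASK bit suggests), OR escalates masking while AND keeps
# only the masking both operands agree on:
print(NO_MASK | PARTIAL_MASK)    # -> PARTIAL_MASK
print(PARTIAL_MASK | FULL_MASK)  # -> FULL_MASK
print(FULL_MASK & PARTIAL_MASK)  # -> PARTIAL_MASK
print(NO_MASK & FULL_MASK)       # -> NO_MASK
```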
--- ## mha_mask
## `comptime` values ### `MASK_VALUE` `comptime MASK_VALUE = -10000` ## Structs * [​`AndMask`](./AndMask): Mask that's the AND of two masks. * [​`CausalMask`](./CausalMask): MHA causal mask ensures a token is only affected by previous tokens. * [​`ChunkedMask`](./ChunkedMask): Mask implementing Chunked attention. * [​`MaskName`](./MaskName): Names for the supported mask variants. * [​`MaterializedMask`](./MaterializedMask): Mask that's backed by a materialized tensor. * [​`NullMask`](./NullMask): Mask that's effectively a noop. * [​`OrMask`](./OrMask): Mask that's the OR of two masks. * [​`SlidingWindowCausalMask`](./SlidingWindowCausalMask): Mask implementing Sliding Window attention. * [​`TileMaskStatus`](./TileMaskStatus): A tile's masking status. ## Traits * [​`MHAMask`](./MHAMask): The MHAMask trait describes masks for MHA kernels, such as the causal mask. ## Functions * [​`ChunkedCausalMask`](./ChunkedCausalMask): Mask implementing Chunked Causal attention for Llama4 models. * [​`naively_compute_total_iters`](./naively_compute_total_iters): * [​`naively_get_first_nonempty_mask_col`](./naively_get_first_nonempty_mask_col):
--- ## naively_compute_total_iters
`naively_compute_total_iters[MaskType: MHAMask, //, BM: Int, BN: Int](mask: MaskType, q_row: UInt32, end: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## naively_get_first_nonempty_mask_col
`naively_get_first_nonempty_mask_col[MaskType: MHAMask, //, BM: Int, BN: Int](mask: MaskType, q_row: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## KVCacheMHAOperand
`@register_passable(trivial)` `struct KVCacheMHAOperand[cache_t: KVCacheT]` An implementation for `mo.opaque` KVCacheT arguments to MHA kernels. We can eventually remove this trait and just add it as a sub-trait in the KVCacheT type, but we need to solve some cyclic dependencies first. ## Fields * ​cache (`cache_t`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAOperand`](/mojo/kernels/nn/mha_operand/MHAOperand), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = cache_t.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = cache_t.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = cache_t.__moveinit__is_trivial` ### `device_type` `comptime device_type = KVCacheMHAOperand[cache_t]` ### `dtype` `comptime dtype = cache_t.dtype` ### `page_size` `comptime page_size = cache_t.page_size_` ## Methods ### `__init__` `__init__(cache: cache_t) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> LegacyUnsafePointer[Scalar[KVCacheMHAOperand[cache_t].dtype]]` **Returns:** `LegacyUnsafePointer` ### `cache_length` `cache_length(self, batch_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `max_context_length` `max_context_length(self) -> UInt32` 
**Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `row_idx` `row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[BN: Int, depth: Int, swizzle_mode: TensorMapSwizzle, BK: Int = depth](self, ctx: DeviceContext, out tma: TMATensorTile[KVCacheMHAOperand[cache_t].dtype, _split_last_layout[KVCacheMHAOperand[cache_t].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[KVCacheMHAOperand[cache_t].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)])` Creates a TMA tile for efficient GPU memory transfers. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)
--- ## LayoutTensorMHAOperand
`@register_passable(trivial)` `struct LayoutTensorMHAOperand[dtype_: DType, layout: Layout]` An implementation for NDBuffer arguments to MHA kernels. ## Fields * ​buffer (`LayoutTensor[LayoutTensorMHAOperand[dtype_, layout].dtype, layout, MutAnyOrigin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAOperand`](/mojo/kernels/nn/mha_operand/MHAOperand), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = LayoutTensorMHAOperand[dtype_, layout]` ### `dtype` `comptime dtype = dtype_` ### `page_size` `comptime page_size = 0` ## Methods ### `__init__` `__init__(buffer: LayoutTensor[LayoutTensorMHAOperand[dtype_, layout].dtype, layout, MutAnyOrigin]) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> LegacyUnsafePointer[Scalar[LayoutTensorMHAOperand[dtype_, layout].dtype]]` **Returns:** `LegacyUnsafePointer` ### `cache_length` `cache_length(self, batch_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `max_context_length` `max_context_length(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### 
`row_idx` `row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[BN: Int, depth: Int, swizzle_mode: TensorMapSwizzle, BK: Int = depth](self, ctx: DeviceContext, out tma: TMATensorTile[LayoutTensorMHAOperand[dtype_, layout].dtype, _split_last_layout[LayoutTensorMHAOperand[dtype_, layout].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[LayoutTensorMHAOperand[dtype_, layout].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)])` Creates a TMA tile for efficient GPU memory transfers. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)
--- ## MHAOperand
This serves as the trait to support arguments to our MHA kernel. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. ### `dtype` `comptime dtype` ### `page_size` `comptime page_size` ## Required methods ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self: _Self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> LegacyUnsafePointer[Scalar[_Self.dtype]]` **Returns:** `LegacyUnsafePointer` ### `cache_length` `cache_length(self: _Self, batch_idx: Int) -> Int` Returns the length of the cache for a given batch index. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `max_context_length` `max_context_length(self: _Self) -> UInt32` Returns the maximum cache length in a given batch index. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `row_idx` `row_idx(self: _Self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. 
**Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[BN: Int, depth: Int, swizzle_mode: TensorMapSwizzle, BK: Int = depth](self: _Self, ctx: DeviceContext) -> TMATensorTile[_Self.dtype, _split_last_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[_Self.dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)]` Creates a TMA tile for efficient GPU memory transfers. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile) ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name.
--- ## RaggedMHAOperand
`@register_passable(trivial)` `struct RaggedMHAOperand[dtype_: DType, layout: Layout, cache_layout: Layout]` An implementation for ragged NDBuffer arguments to MHA kernels. ## Fields * ​buffer (`LayoutTensor[RaggedMHAOperand[dtype_, layout, cache_layout].dtype, layout, MutAnyOrigin]`): * ​cache\_row\_offsets (`LayoutTensor[DType.uint32, cache_layout, MutAnyOrigin]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAOperand`](/mojo/kernels/nn/mha_operand/MHAOperand), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = RaggedMHAOperand[dtype_, layout, cache_layout]` ### `dtype` `comptime dtype = dtype_` ### `page_size` `comptime page_size = 0` ## Methods ### `__init__` `__init__(buffer: LayoutTensor[RaggedMHAOperand[dtype_, layout, cache_layout].dtype, layout, MutAnyOrigin], cache_row_offsets: LayoutTensor[DType.uint32, cache_layout, MutAnyOrigin]) -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `block_paged_ptr` `block_paged_ptr[tile_size: Int](self, batch_idx: UInt32, start_tok_idx: UInt32, head_idx: UInt32, head_dim_idx: UInt32 = 0) -> LegacyUnsafePointer[Scalar[RaggedMHAOperand[dtype_, layout, cache_layout].dtype]]` **Returns:** `LegacyUnsafePointer` ### `cache_length` 
`cache_length(self, batch_idx: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `max_context_length` `max_context_length(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `row_idx` `row_idx(self, batch_idx: UInt32, start_tok_idx: UInt32) -> UInt32` Returns the row idx when viewing the memory as a matrix. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `create_tma_tile` `create_tma_tile[BN: Int, depth: Int, swizzle_mode: TensorMapSwizzle, BK: Int = depth](self, ctx: DeviceContext, out tma: TMATensorTile[RaggedMHAOperand[dtype_, layout, cache_layout].dtype, _split_last_layout[RaggedMHAOperand[dtype_, layout, cache_layout].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode, True), _ragged_desc_layout[RaggedMHAOperand[dtype_, layout, cache_layout].dtype](IndexList[3, DType.int64](BN, 1, BK, Tuple[]()), swizzle_mode)])` Creates a TMA tile for efficient GPU memory transfers. **Returns:** [`TMATensorTile`](/mojo/kernels/layout/tma_async/TMATensorTile)
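For intuition, ragged operands are typically indexed CSR-style through per-batch row offsets (the `cache_row_offsets` name comes from the struct's fields; the arithmetic below is the conventional prefix-offset scheme, stated as an assumption rather than the documented implementation):

```python
# cache_row_offsets holds prefix offsets into the flattened token
# dimension: batch b owns rows [offsets[b], offsets[b + 1]).
offsets = [0, 3, 7, 12]  # three batches with 3, 4, and 5 tokens

def cache_length(batch_idx):
    # Number of cached tokens for this batch entry.
    return offsets[batch_idx + 1] - offsets[batch_idx]

def row_idx(batch_idx, start_tok_idx):
    # Flattened row of a token "when viewing the memory as a matrix".
    return offsets[batch_idx] + start_tok_idx

print(cache_length(1))  # 4
print(row_idx(2, 1))    # 8
```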
--- ## mha_operand
## Structs * [​`KVCacheMHAOperand`](./KVCacheMHAOperand): An implementation for `mo.opaque` KVCacheT arguments to MHA kernels. * [​`LayoutTensorMHAOperand`](./LayoutTensorMHAOperand): An implementation for NDBuffer arguments to MHA kernels. * [​`RaggedMHAOperand`](./RaggedMHAOperand): An implementation for ragged NDBuffer arguments to MHA kernels. ## Traits * [​`MHAOperand`](./MHAOperand): This serves as the trait to support arguments to our MHA kernel.
--- ## AlibiScoreMod
`@register_passable(trivial)` `struct AlibiScoreMod[num_heads: Int]` AlibiScoreMod adds the appropriate ALiBi constant bias to attention score. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`ScoreModTrait`](/mojo/kernels/nn/mha_score_mod/ScoreModTrait), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = AlibiScoreMod[num_heads]` ### `name_str` `comptime name_str = "alibi"` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `score_mod` `score_mod[dtype: DType, width: Int, //, *, element_type: DType = DType.int32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width], max_prompt_len: Int) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
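As background, the standard ALiBi formulation gives head h a slope of 2^(−8·(h+1)/num_heads) and adds slope · (k − q) to each score, so the penalty grows with key distance. The sketch below is illustrative Python under that standard formulation; the exact constants inside AlibiScoreMod are not shown in this reference:

```python
def alibi_slope(head, num_heads):
    # Geometric slope sequence from the ALiBi formulation
    # (exact for power-of-two head counts).
    return 2.0 ** (-8.0 * (head + 1) / num_heads)

def alibi_bias(head, num_heads, q_idx, k_idx):
    # Bias is 0 at k == q and increasingly negative for more
    # distant (older) keys.
    return alibi_slope(head, num_heads) * (k_idx - q_idx)

print(alibi_slope(0, 8))       # 0.5
print(alibi_bias(0, 8, 4, 4))  # 0.0
print(alibi_bias(0, 8, 4, 2))  # -1.0
```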
--- ## IdentityScoreMod
`@register_passable(trivial)` `struct IdentityScoreMod` IdentityScoreMod simply returns attention score. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`ScoreModTrait`](/mojo/kernels/nn/mha_score_mod/ScoreModTrait), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = IdentityScoreMod` ### `name_str` `comptime name_str = "no_pos"` ## Methods ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `score_mod` `score_mod[dtype: DType, width: Int, //, *, element_type: DType = DType.int32](self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width], max_prompt_len: Int = 0) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## ScoreModTrait
The ScoreModTrait describes a score\_mod for MHA kernels, such as the ALiBi bias. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. 
### `name_str` `comptime name_str` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `score_mod` `score_mod[dtype: DType, width: Int, //, *, element_type: DType = DType.int32](self: _Self, coord: IndexList[4, element_type=element_type], score_vec: SIMD[dtype, width], max_prompt_len: Int = 0) -> SIMD[dtype, width]` Returns the score vector at the given coordinates after applying the score\_mod. Arguments: `coord` is (seq\_id, head, q\_idx, k\_idx); `score_vec` is the value at `coord` of the score matrix. The score\_mod computes a bias tensor from the functor and adds it to `score_vec`. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## mha_score_mod
## Structs * [​`AlibiScoreMod`](./AlibiScoreMod): AlibiScoreMod adds the appropriate ALiBi constant bias to the attention score. * [​`IdentityScoreMod`](./IdentityScoreMod): IdentityScoreMod returns the attention score unchanged. ## Traits * [​`ScoreModTrait`](./ScoreModTrait): The ScoreModTrait describes a score\_mod for MHA kernels, such as the ALiBi bias.
--- ## AccumulatorTile
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
### `dtype` `comptime dtype` ### `element_layout` `comptime element_layout` ### `rows_of_frags_layout` `comptime rows_of_frags_layout` ### `vec_output_layout` `comptime vec_output_layout` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `rows_of_frags` `static rows_of_frags(src: LayoutTensor[_Self.dtype, _Self.vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=_Self.element_layout]) -> LayoutTensor[_Self.dtype, _Self.rows_of_frags_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `allocate_register_tile` `static allocate_register_tile() -> LayoutTensor[_Self.dtype, _Self.vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=_Self.element_layout]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_from` `copy_from(self: _Self, src: LayoutTensor[_Self.dtype, _Self.vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=_Self.element_layout])` ### `copy_to` `copy_to(self: _Self, dst: LayoutTensor[_Self.dtype, _Self.vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=_Self.element_layout])` ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## DescriptorPair
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `a_t` `comptime a_t` ### `b_t` `comptime b_t` ## Required methods ### `get_a` `get_a(self: _Self) -> _Self.a_t` **Returns:** `_Self.a_t` ### `get_b` `get_b(self: _Self) -> _Self.b_t` **Returns:** `_Self.b_t`
--- ## DescriptorPairTS
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler-generated) indicating whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, this means that `__del__` can be treated as a no-op. ### `a_t` `comptime a_t` ### `b_t` `comptime b_t` ## Required methods ### `get_a` `get_a(self: _Self) -> _Self.a_t` **Returns:** `_Self.a_t` ### `get_b` `get_b(self: _Self) -> _Self.b_t` **Returns:** `_Self.b_t`
--- ## MMAOperandOffsetFn
`@register_passable(trivial)` `struct MMAOperandOffsetFn[dtype: DType, BMN: Int, BK: Int, swizzle: TensorMapSwizzle, is_k_major: Bool, WMMA_MN: Int, WMMA_K: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `canonical_K` `comptime canonical_K = (swizzle.bytes() // size_of[dtype]()) if (swizzle != TensorMapSwizzle(Int32(0))) else BK` ### `canonical_layout` `comptime canonical_layout = tile_to_descriptor[dtype, MMAOperandOffsetFn[dtype, BMN, BK, swizzle, is_k_major, WMMA_MN, WMMA_K].canonical_layout_flat, is_k_major]()` ### `canonical_layout_flat` `comptime canonical_layout_flat = tile_layout_k_major[dtype, BMN, MMAOperandOffsetFn[dtype, BMN, BK, swizzle, is_k_major, WMMA_MN, WMMA_K].canonical_K, swizzle]() if is_k_major else MMAOperandOffsetFn[dtype, BMN, BK, swizzle, is_k_major, WMMA_MN, WMMA_K].layout` ### `canonical_layout_size` `comptime canonical_layout_size = MMAOperandOffsetFn[dtype, BMN, BK, swizzle, is_k_major, WMMA_MN, WMMA_K].canonical_layout.size()` ### `layout` `comptime layout = tile_layout_k_major[dtype, BMN, BK, swizzle]() if is_k_major else tile_layout_mn_major[dtype, BMN, BK, swizzle]()` ### `layout_size` `comptime layout_size = MMAOperandOffsetFn[dtype, BMN, BK, swizzle, is_k_major, WMMA_MN, WMMA_K].layout.size()` ## Methods ### `__init__` `__init__() -> Self`
--- ## RegisterAccumulatorDescription
`struct RegisterAccumulatorDescription` ## Fields * ​num\_mmas (`Int`): * ​frag\_size (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, num_mmas: Int, frag_size: Int)`
--- ## RegisterAccumulatorLayout
`@register_passable(trivial)` `struct RegisterAccumulatorLayout[MMA_M: Int, MMA_N: Int, num_m_mmas: Int, num_n_mmas: Int, consumer_group_size: Int, *, frag_simdwidth: Int = 2]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `element_layout` `comptime element_layout = Layout.row_major(1, frag_simdwidth)` ### `frag_size` `comptime frag_size = ((MMA_M * MMA_N) // consumer_group_size)` ### `num_row_blocks_per_mma` `comptime num_row_blocks_per_mma = 2` ### `rows_of_frags_layout` `comptime rows_of_frags_layout = Layout.row_major((num_m_mmas * num_n_mmas), RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, consumer_group_size, frag_simdwidth=frag_simdwidth].frag_size)` ### `vec_output_layout` `comptime vec_output_layout = Layout(IntTuple(IntTuple(2, num_m_mmas), IntTuple((RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, consumer_group_size, frag_simdwidth=frag_simdwidth].frag_size // (2 * frag_simdwidth)), num_n_mmas), Tuple[]()), IntTuple(IntTuple(frag_simdwidth, RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, consumer_group_size, frag_simdwidth=frag_simdwidth].frag_size), IntTuple((2 * frag_simdwidth), (num_m_mmas * RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, consumer_group_size, frag_simdwidth=frag_simdwidth].frag_size)), Tuple[]()))` ## Methods ### `description` `static description() -> RegisterAccumulatorDescription` **Returns:** `RegisterAccumulatorDescription`
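The arithmetic behind `frag_size` divides one MMA tile's accumulator elements evenly among the consumer threads. A worked example with illustrative parameter values (not taken from any particular kernel configuration):

```mojo
# Illustrative values only; frag_size mirrors the comptime member above.
comptime MMA_M = 64
comptime MMA_N = 64
comptime consumer_group_size = 128   # threads sharing one MMA tile
# frag_size = (MMA_M * MMA_N) // consumer_group_size
#           = (64 * 64) // 128 = 32 accumulator elements per thread
comptime frag_size = (MMA_M * MMA_N) // consumer_group_size
```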
--- ## SM100TensorAccumulatorSS
`@register_passable(trivial)` `struct SM100TensorAccumulatorSS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, compute_BK: Int, num_softmax_threads: Int, swizzle_a: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, *, transpose_b: Bool = True, cta_group: Int = 1, pipeline_stages: Int = 1]` ## Fields * ​mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​pipeline (`PipelineState[pipeline_stages]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_offset` `comptime a_offset = MMAOperandOffsetFn[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t, BM, BK, swizzle_a, True, MMA_M, 16]()` ### `a_t` `comptime a_t = SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].ab_t.a_t` ### `ab_t` `comptime ab_t = UMMADescriptorSS[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t]` ### `accum_t` `comptime accum_t = accum_type` ### `b_offset` `comptime b_offset = 
MMAOperandOffsetFn[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t, BN, BK, swizzle_b, transpose_b, MMA_N, 16]()` ### `b_t` `comptime b_t = SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].ab_t.b_t` ### `c_t` `comptime c_t = TMemAccumulator[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, (BM // SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_n_mmas, num_softmax_threads]` ### `idesc` `comptime idesc = UMMAInsDescriptor.create[UMMAKind.KIND_F16, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, 
pipeline_stages=pipeline_stages].operand_t, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t, Index[dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()` ### `MMA_K` `comptime MMA_K = 16` ### `num_k_mmas` `comptime num_k_mmas = (compute_BK // 16)` ### `num_m_blocks_per_warp` `comptime num_m_blocks_per_warp = ((2 * BM) // num_softmax_threads)` ### `num_m_mmas` `comptime num_m_mmas = (BM // MMA_M)` ### `num_n_mmas` `comptime num_n_mmas = (BN // MMA_N)` ### `operand_t` `comptime operand_t = operand_type` ### `smem_ptr_t` `comptime smem_ptr_t = LegacyUnsafePointer[Scalar[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].operand_t], address_space=AddressSpace.SHARED]` ## Methods ### `__init__` `__init__(smem: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `check_constraints` `static check_constraints()` ### `init` `init(self)` ### `mma_descriptors` `static mma_descriptors[dtype_a: DType, dtype_b: DType](p_a: LegacyUnsafePointer[Scalar[dtype_a], address_space=AddressSpace.SHARED], p_b: LegacyUnsafePointer[Scalar[dtype_b], address_space=AddressSpace.SHARED]) -> SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].ab_t` **Returns:** `SM100TensorAccumulatorSS` ### `mma` `mma(mut self, a: MMASmemDescriptor, b: MMASmemDescriptor, c_base: TMemAccumulator[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, 
pipeline_stages=pipeline_stages].accum_t, (BM // SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_n_mmas, num_softmax_threads], scale_c: UInt32)` ### `wait_for_tmem` `wait_for_tmem(self)` Wait for the accumulator tmem to finish being read. ### `wait_for_mma` `wait_for_mma(self, c_base: TMemAccumulator[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].accum_t, (BM // SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_m_blocks_per_warp, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].num_n_mmas, num_softmax_threads]) -> SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, compute_BK, 
num_softmax_threads, swizzle_a, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, pipeline_stages=pipeline_stages].c_t` Wait for the accumulator tmem to finish being read. **Returns:** `SM100TensorAccumulatorSS` ### `tmem_arrive_init` `tmem_arrive_init(self)` ### `tmem_arrive` `tmem_arrive(mut self)` Indicate that the accumulator is ready to be updated.
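Read together, `wait_for_tmem`, `mma`, and `tmem_arrive` describe a wait/issue/signal cycle for the warp driving the tensor cores. A hedged sketch of that cycle (the loop structure, `accum`, and the operand variables are illustrative; only the method names and their call ordering come from the API above):

```mojo
# Sketch of the intended call ordering, not a complete kernel.
accum.init()                       # set up the shared-memory barrier once
for _ in range(num_stages):        # hypothetical pipeline loop
    accum.wait_for_tmem()          # wait until consumers have read tmem
    accum.mma(a_desc, b_desc, c_base, scale_c)  # issue the MMA
    accum.tmem_arrive()            # signal: accumulator ready to update
```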
--- ## SM100TensorAccumulatorTS
`@register_passable(trivial)` `struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BM: Int, BN: Int, BK: Int, num_softmax_threads: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = 1]` ## Fields * ​mbar (`LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]`): * ​phase (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_frag_size` `comptime a_frag_size = ((MMA_M * 16) // num_softmax_threads)` ### `a_t` `comptime a_t = SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.a_t` ### `ab_t` `comptime ab_t = UMMADescriptorTS[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, MMA_M=(BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N=BK, MMA_K=16, consumer_group_size=num_softmax_threads]` ### `accum_t` `comptime accum_t = accum_type` ### `b_offset` `comptime 
b_offset = MMAOperandOffsetFn[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, BN, BK, swizzle_b, transpose_b, MMA_N, 16]()` ### `b_t` `comptime b_t = SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.b_t` ### `c_frag_size` `comptime c_frag_size = ((MMA_M * MMA_N) // num_softmax_threads)` ### `c_t` `comptime c_t = TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads]` ### `idesc` `comptime idesc = UMMAInsDescriptor.create[UMMAKind.KIND_F16, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, Index[dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()` ### `MMA_K` `comptime MMA_K = 16` ### `num_k_mmas` `comptime num_k_mmas = (BK // 16)` ### `num_m_blocks_per_warp` `comptime num_m_blocks_per_warp = ((2 * BM) // num_softmax_threads)` ### `num_m_mmas` `comptime num_m_mmas = (BM // MMA_M)` ### `num_n_mmas` `comptime num_n_mmas 
= (BN // MMA_N)` ### `operand_t` `comptime operand_t = operand_type` ### `smem_ptr_t` `comptime smem_ptr_t = LegacyUnsafePointer[Scalar[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t], address_space=AddressSpace.SHARED]` ## Methods ### `__init__` `__init__(smem: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `check_constraints` `static check_constraints()` ### `init` `init(self)` ### `a_mma_descriptor` `static a_mma_descriptor(a_tmem: UInt32) -> SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.a_t` **Returns:** `SM100TensorAccumulatorTS` ### `b_mma_descriptor` `static b_mma_descriptor[dtype_b: DType](p_b: LegacyUnsafePointer[Scalar[dtype_b], address_space=AddressSpace.SHARED]) -> SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].ab_t.b_t` **Returns:** `SM100TensorAccumulatorTS` ### `mma` `mma(self, a: TMemOperand[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, (BM // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), BK, 16, num_softmax_threads], b: MMASmemDescriptor, c: TMemAccumulator[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].accum_t, (BM // 
SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp), MMA_N, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_m_blocks_per_warp, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BM, BN, BK, num_softmax_threads, swizzle_b, transpose_b, cta_group].num_n_mmas, num_softmax_threads], c_scale: UInt32)` ### `wait` `wait(mut self, idx: UInt32)` ### `wait_for_mma` `wait_for_mma(mut self)` Wait for the mma to be complete. ### `wait_for_tmem` `wait_for_tmem(mut self)` Wait for the `output` and `A` tmem to be ready. ### `tmem_arrive` `tmem_arrive(self)` Indicate that the accumulator and the tensor memory arguments are ready for the MMA to begin.
--- ## TMemAccumulator
`@register_passable(trivial)` `struct TMemAccumulator[dtype_: DType, MMA_M: Int, MMA_N: Int, num_m_mmas: Int, num_n_mmas: Int, num_softmax_threads: Int]` ## Fields * ​tmem\_addr (`UInt32`): ## Implemented traits [`AccumulatorTile`](/mojo/kernels/nn/mha_sm100_1q/AccumulatorTile), [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `dtype` `comptime dtype = dtype_` ### `element_layout` `comptime element_layout = TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].layout_t.element_layout` ### `frag_size` `comptime frag_size = TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].layout_t.frag_size` ### `layout_t` `comptime layout_t = RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads]` ### `rows_of_frags_layout` `comptime rows_of_frags_layout = TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].layout_t.rows_of_frags_layout` ### `vec_output_layout` `comptime vec_output_layout = TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].layout_t.vec_output_layout` ## Methods ### `__init__` `__init__(tmem_addr: UInt32) -> Self` ### `__getitem__` `__getitem__(self, i: UInt32) -> Self` ### `check_constraints` `static check_constraints()` ### `offset` `offset[m_mma: Int, n_mma: Int](self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `rows_of_frags` `static rows_of_frags(src: LayoutTensor[TMemAccumulator[dtype_, MMA_M, MMA_N, 
num_m_mmas, num_n_mmas, num_softmax_threads].dtype, TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].layout_t.element_layout]) -> LayoutTensor[TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].dtype, TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].rows_of_frags_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `allocate_register_tile` `static allocate_register_tile() -> LayoutTensor[TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].dtype, TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].layout_t.element_layout]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `copy_from` `copy_from(self, src: LayoutTensor[TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].dtype, TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].layout_t.element_layout])` ### `copy_to` `copy_to(self, dst: LayoutTensor[TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].dtype, TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads].vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=TMemAccumulator[dtype_, MMA_M, MMA_N, num_m_mmas, num_n_mmas, 
num_softmax_threads].layout_t.element_layout])`
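`allocate_register_tile`, `copy_to`, and `copy_from` support a round-trip staging pattern between tensor memory and registers. A sketch under the assumption that `acc` is a `TMemAccumulator` instance (the intermediate processing step is hypothetical):

```mojo
# Stage the accumulator through a register tile, then write it back.
var reg_tile = __type_of(acc).allocate_register_tile()  # LOCAL-space tile
acc.copy_to(reg_tile)     # tensor memory -> registers
# ... operate on reg_tile in registers (e.g. rescale rows) ...
acc.copy_from(reg_tile)   # registers -> tensor memory
```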
--- ## TMemOperand
`@register_passable(trivial)` `struct TMemOperand[dtype: DType, num_m_mmas: Int, num_n_mmas: Int, MMA_M: Int, MMA_N: Int, MMA_K: Int, num_softmax_threads: Int]` ## Fields * ​tmem\_addr (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`WriteableMMAOperandDescriptor`](/mojo/kernels/nn/mha_sm100_1q/WriteableMMAOperandDescriptor) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `frag_size` `comptime frag_size = TMemOperand[dtype, num_m_mmas, num_n_mmas, MMA_M, MMA_N, MMA_K, num_softmax_threads].reg_layout.frag_size` ### `reg_layout` `comptime reg_layout = RegisterAccumulatorLayout[MMA_M, MMA_N, num_m_mmas, num_n_mmas, num_softmax_threads]` ### `reg_tile_t` `comptime reg_tile_t = LayoutTensor[dtype, TMemOperand[dtype, num_m_mmas, num_n_mmas, MMA_M, MMA_N, MMA_K, num_softmax_threads].vec_output_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=TMemOperand[dtype, num_m_mmas, num_n_mmas, MMA_M, MMA_N, MMA_K, num_softmax_threads].reg_layout.element_layout]` ### `vec_output_layout` `comptime vec_output_layout = TMemOperand[dtype, num_m_mmas, num_n_mmas, MMA_M, MMA_N, MMA_K, num_softmax_threads].reg_layout.vec_output_layout` ## Methods ### `__init__` `__init__(tmem_addr: UInt32) -> Self` ### `offset` `offset[m_mma: Int, k_mma: Int](self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `copy_from` `copy_from[src_type: DType, src_layout: Layout, src_element_layout: Layout, //](self, src: LayoutTensor[src_type, src_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, 
element_layout=src_element_layout])` ### `copy_to` `copy_to[dst_type: DType, dst_layout: Layout, dst_element_layout: Layout, //](self, dst: LayoutTensor[dst_type, dst_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=dst_element_layout])`
--- ## UMMADescriptorSS
`@register_passable(trivial)` `struct UMMADescriptorSS[operand_type: DType]` ## Fields * ​a (`UMMADescriptorSS[operand_type].a_t`): * ​b (`UMMADescriptorSS[operand_type].b_t`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DescriptorPair`](/mojo/kernels/nn/mha_sm100_1q/DescriptorPair), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_t` `comptime a_t = MMASmemDescriptor` ### `b_t` `comptime b_t = MMASmemDescriptor` ### `operand_t` `comptime operand_t = operand_type` ## Methods ### `__init__` `__init__(a: MMASmemDescriptor, b: MMASmemDescriptor) -> Self` ### `get_a` `get_a(self) -> UMMADescriptorSS[operand_type].a_t` **Returns:** `UMMADescriptorSS` ### `get_b` `get_b(self) -> UMMADescriptorSS[operand_type].b_t` **Returns:** `UMMADescriptorSS`
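`UMMADescriptorSS` is the smem-smem instance of `DescriptorPair`: both operands are `MMASmemDescriptor`s. A usage sketch, assuming `a_desc` and `b_desc` are descriptors already built for shared-memory tiles (their construction is out of scope here, and the chosen dtype is illustrative):

```mojo
# Pair two shared-memory descriptors and read them back through
# the DescriptorPair interface.
var pair = UMMADescriptorSS[DType.bfloat16](a_desc, b_desc)
var a = pair.get_a()   # MMASmemDescriptor for the A operand
var b = pair.get_b()   # MMASmemDescriptor for the B operand
```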
--- ## UMMADescriptorTS
`@register_passable(trivial)` `struct UMMADescriptorTS[operand_type: DType, num_m_mmas: Int, num_n_mmas: Int, *, MMA_M: Int, MMA_N: Int, MMA_K: Int, consumer_group_size: Int]` ## Fields * ​a (`UMMADescriptorTS[operand_type, num_m_mmas, num_n_mmas, MMA_M=MMA_M, MMA_N=MMA_N, MMA_K=MMA_K, consumer_group_size=consumer_group_size].a_t`): * ​b (`UMMADescriptorTS[operand_type, num_m_mmas, num_n_mmas, MMA_M=MMA_M, MMA_N=MMA_N, MMA_K=MMA_K, consumer_group_size=consumer_group_size].b_t`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DescriptorPairTS`](/mojo/kernels/nn/mha_sm100_1q/DescriptorPairTS), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_t` `comptime a_t = TMemOperand[operand_type, num_m_mmas, num_n_mmas, MMA_M, MMA_N, MMA_K, consumer_group_size]` ### `b_t` `comptime b_t = MMASmemDescriptor` ### `operand_t` `comptime operand_t = operand_type` ## Methods ### `__init__` `__init__(a: TMemOperand[operand_type, num_m_mmas, num_n_mmas, MMA_M, MMA_N, MMA_K, consumer_group_size], b: MMASmemDescriptor) -> Self` ### `get_a` `get_a(self) -> UMMADescriptorTS[operand_type, num_m_mmas, num_n_mmas, MMA_M=MMA_M, MMA_N=MMA_N, MMA_K=MMA_K, consumer_group_size=consumer_group_size].a_t` **Returns:** `UMMADescriptorTS` ### `get_b` `get_b(self) -> UMMADescriptorTS[operand_type, num_m_mmas, num_n_mmas, MMA_M=MMA_M, MMA_N=MMA_N, MMA_K=MMA_K, consumer_group_size=consumer_group_size].b_t` **Returns:** `UMMADescriptorTS`
--- ## WriteableMMAOperandDescriptor
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler-generated) indicating whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, this means that `__del__` can be treated as a no-op. ## Required methods ### `copy_from` `copy_from[src_type: DType, src_layout: Layout, src_element_layout: Layout, //](self: _Self, src: LayoutTensor[src_type, src_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=src_element_layout])`
--- ## mha_sm100_1q
## `comptime` values ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ## Structs * [​`MMAOperandOffsetFn`](./MMAOperandOffsetFn): * [​`RegisterAccumulatorDescription`](./RegisterAccumulatorDescription): * [​`RegisterAccumulatorLayout`](./RegisterAccumulatorLayout): * [​`SM100TensorAccumulatorSS`](./SM100TensorAccumulatorSS): * [​`SM100TensorAccumulatorTS`](./SM100TensorAccumulatorTS): * [​`TMemAccumulator`](./TMemAccumulator): * [​`TMemOperand`](./TMemOperand): * [​`UMMADescriptorSS`](./UMMADescriptorSS): * [​`UMMADescriptorTS`](./UMMADescriptorTS): ## Traits * [​`AccumulatorTile`](./AccumulatorTile): * [​`DescriptorPair`](./DescriptorPair): * [​`DescriptorPairTS`](./DescriptorPairTS): * [​`WriteableMMAOperandDescriptor`](./WriteableMMAOperandDescriptor): ## Functions * [​`local_tensor_type`](./local_tensor_type): * [​`mha_sm100_dispatch`](./mha_sm100_dispatch):
--- ## local_tensor_type
`local_tensor_type[dtype: DType, layout: Layout, element_layout: Layout]() -> LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## mha_sm100_dispatch
`mha_sm100_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, ScoreModType: ScoreModTrait, output_type: DType, MaxPromptLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, //, config: MHAConfig[dtype], group: Int, use_score_mod: Bool, ragged: Bool, sink: Bool, _is_cache_length_accurate: Bool](output: DeviceBuffer[output_type], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, num_rows_q: Int, mask: MaskType, score_mod: ScoreModType, valid_length: DeviceBuffer[DType.uint32], max_prompt_len_arg: MaxPromptLenType, max_cache_valid_length_arg: Int, scale: Float32, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]], batch_size_arg: Int, partition: PartitionType, ctx: DeviceContext, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]])`
--- ## ConsumerPipeline
`@register_passable(trivial)` `struct ConsumerPipeline[number_of_stages: Int]` ## Fields * ​mbar (`MBarType`): * ​state (`PipelineState[ConsumerPipeline[number_of_stages].num_stages]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_stages` `comptime num_stages = number_of_stages` ## Methods ### `__init__` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `producer_mbar` `producer_mbar(self) -> MBarType` **Returns:** `MBarType` ### `consumer_mbar` `consumer_mbar(self) -> MBarType` **Returns:** `MBarType` ### `wait` `wait(self)` ### `release` `release(mut self)` ### `step` `step(mut self)`
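The `wait`/`release`/`step` cycle above follows the usual mbarrier pipeline pattern, where a `PipelineState` tracks a ring index plus a phase parity bit. A minimal Python model of that state machine (an assumption based on the standard mbarrier pipeline idiom, not code taken from this module):

```python
class PipelineStateModel:
    """Illustrative model of PipelineState[num_stages]: a ring index plus a
    phase bit that flips each time the index wraps around the ring.
    (Assumed from the standard mbarrier pipeline pattern, not this source.)"""

    def __init__(self, num_stages: int) -> None:
        self.num_stages = num_stages
        self.index = 0  # which mbar slot to wait on next
        self.phase = 0  # parity the waiter expects that slot to reach

    def step(self) -> None:
        # advance to the next slot; flip the phase on wraparound
        self.index += 1
        if self.index == self.num_stages:
            self.index = 0
            self.phase ^= 1
```

In this model, a consumer's `wait` would block until slot `index` reaches `phase`, `release` would arrive on the paired producer barrier, and `step` advances the state to the next slot.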
--- ## FA4Config
`@register_passable(trivial)` `struct FA4Config` ## Fields * ​MMA\_M (`Int`): * ​BM (`Int`): * ​BN (`Int`): * ​BK0 (`Int`): * ​BK1 (`Int`): * ​depth (`Int`): * ​padded\_depth (`Int`): * ​group (`Int`): * ​num\_q\_heads (`Int`): * ​num\_kv\_heads (`Int`): * ​TMEM\_S1 (`Int`): * ​TMEM\_O0 (`Int`): * ​TMEM\_O1 (`Int`): * ​TMEM\_P0 (`Int`): * ​TMEM\_P1 (`Int`): * ​TMEM\_C0 (`Int`): * ​TMEM\_C1 (`Int`): * ​tmem\_used (`Int`): * ​num\_kv\_stages (`Int`): * ​num\_mma\_stages (`Int`): * ​smem\_used (`Int`): * ​dtype\_size (`Int`): * ​split\_m (`Bool`): * ​swizzle\_mode (`TensorMapSwizzle`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `mbar_size` `comptime mbar_size = size_of[DType.int64]()` ### `MMA_K` `comptime MMA_K = 16` ### `num_correction_cols` `comptime num_correction_cols = 1` ### `num_threads` `comptime num_threads = 512` ### `sm100_smem_carveout` `comptime sm100_smem_carveout = (B200 - 1024)` ### `sm100_tmem_cols` `comptime sm100_tmem_cols = 512` ### `TMEM_S0` `comptime TMEM_S0 = 0` ## Methods ### `__init__` `__init__(*, num_q_heads: Int, group: Int, depth: Int, dtype_size: Int, swizzle_mode: TensorMapSwizzle, page_size: Int, is_mla: Bool = False) -> Self` ### `num_qo` `num_qo(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `supported` `supported(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `use_tmem_for_correction` `use_tmem_for_correction(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### 
`correction_smem_elements` `correction_smem_elements(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `num_active_warps_per_group` `num_active_warps_per_group(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `num_active_threads_per_group` `num_active_threads_per_group(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## FA4MiscMBars
`@register_passable(trivial)` `struct FA4MiscMBars` ## Fields * ​mbar\_base (`MBarType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `C0_offset` `comptime C0_offset = 4` ### `C1_offset` `comptime C1_offset = 6` ### `order_offset` `comptime order_offset = 8` ### `Q1SyncIdx` `comptime Q1SyncIdx = 10` ### `S0_offset` `comptime S0_offset = 0` ### `S1_offset` `comptime S1_offset = 2` ### `size` `comptime size = 11` ## Methods ### `__init__` `__init__(mbar_base: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `init` `init(self)` ### `producer_s0` `producer_s0(self) -> ProducerPipeline[1]` **Returns:** `ProducerPipeline` ### `producer_s1` `producer_s1(self) -> ProducerPipeline[1]` **Returns:** `ProducerPipeline` ### `consumer_s` `consumer_s(self, wg_idx: UInt32) -> ConsumerPipeline[1]` **Returns:** `ConsumerPipeline` ### `consumer_c0` `consumer_c0(self) -> ConsumerPipeline[1]` **Returns:** `ConsumerPipeline` ### `consumer_c1` `consumer_c1(self) -> ConsumerPipeline[1]` **Returns:** `ConsumerPipeline` ### `producer_c` `producer_c(self, wg_idx: UInt32) -> ProducerPipeline[1]` **Returns:** `ProducerPipeline` ### `pipeline_order_wait` `pipeline_order_wait(self, wg_idx: UInt32) -> MBarType` **Returns:** `MBarType` ### `pipeline_order_arrive` `pipeline_order_arrive(self, wg_idx: UInt32) -> MBarType` **Returns:** `MBarType` ### `q1_wait_mbar` `q1_wait_mbar(self) -> ref [MutAnyOrigin, 3] SharedMemBarrier` **Returns:** `ref` ### `end` `end(self) 
-> MBarType` **Returns:** `MBarType`
--- ## KVConsumerPipeline
`@register_passable(trivial)` `struct KVConsumerPipeline[dtype: DType, config: FA4Config]` Pipeline for managing the consumption of K and V. This follows the order of the Tri Dao and Cutlass implementations (modulo any rotation of the ops through the iterations). We consume/produce in the following order:

0. S0 <- Q0 @ Kn'
1. O1 <- O1 + P1 @ V{n-1}
2. S1 <- Q1 @ Kn'
3. O0 <- O0 + P0 @ Vn

Note that we have two MMAs between calculating Si and consuming Pi, maximizing the overlap between MMAs and the softmax calculation. Oi + Pi @ V also depends on the correction, which is computed asynchronously with the softmax in a correction warpgroup (as soon as the softmax writes the correction factor).

```
# wait on K0
S0 <- Q0 @ K0'
S1 <- Q1 @ K0'
# release K0
# wait on V0
O0 <- P0 @ V0
for n in range(1, num_iters):
    # wait on Kn
    S0 <- Q0 @ Kn'
    O1 <- O1 + P1 @ V{n-1}
    # release V{n-1}
    S1 <- Q1 @ Kn'
    # release Kn
    # wait on Vn
    O0 <- P0 @ Vn
O1 <- O1 + P1 @ V{num_iters-1}
```

The resulting wait (w) / release (r) schedule per iteration:

```
wK0, rK0, wV0
wK1, rV0, rK1, wV1
wK2, rV1, rK2, wV2
wK3, rV2, rK3, wV3
```

With pipeline indices shown as `wKn(state)`:

```
wK0(0), rK0(0), wV0(1)
wK1(2), rV0(1), rK1(2), wV1(3)
wK2(4), rV1(3), rK2(4), wV2(5)
wK3(6), rV2(5), rK3(6), wV3(7)
```

Rules:

* wK backs up and increments prior to waiting, except K0
* rK increments after releasing
* rV uses the backup

```
wK0(0), rK0(0), wV0(1)
wK1(2), rV0(1), rK1(2), wV1(3)
wK2(4), rV1(3), rK2(4), wV2(5)
rV2(5)
```

## Fields * ​kv\_pipeline (`KVPipeline[config.num_kv_stages, config.num_mma_stages]`): * ​k\_smem\_descriptor (`MMASmemDescriptorPair`): * ​v\_smem\_descriptor (`MMASmemDescriptorPair`): * ​v\_pipeline\_release\_index (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial
= True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `full_kv_bytes` `comptime full_kv_bytes = ((config * config) * size_of[dtype]())` ### `mma_kv_bytes` `comptime mma_kv_bytes = ((config * config) * size_of[dtype]())` ## Methods ### `__init__` `__init__(kv_pipeline: KVPipeline[config.num_kv_stages, config.num_mma_stages], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self` ### `init` `init(self)` Only one of the producer or consumer should call `init()`. ### `wait` `wait[*, mma_stage: Int](self) -> UInt32` Wait on `k` from the producer, and return the `k` smem descriptor. **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `wait_k` `wait_k[*, mma_stage: Int = (config - 1), pre_increment: Bool = True](mut self) -> MMASmemDescriptorPair` Wait on `k` from the producer, and return the `k` smem descriptor. If `pre_increment` is `True`, the pipeline index is incremented before waiting. **Returns:** [`MMASmemDescriptorPair`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptorPair) ### `wait_v` `wait_v[*, mma_stage: Int = (config - 1)](self) -> MMASmemDescriptorPair` **Returns:** [`MMASmemDescriptorPair`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptorPair) ### `release_k` `release_k[*, mma_stage: Int = (config - 1)](mut self, e: Int32)` Must call `producer_commit` on the tmem resource before calling `consumer_release`. `release_k` does increment the pipeline step. ### `release_v` `release_v[*, mma_stage: Int = (config - 1)](self, e: Int32)` Must call `producer_commit` on the tmem resource before calling `consumer_release`. `release_v` does not increment the pipeline step.
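The wait/release index rules in the struct docstring can be sketched as a small Python simulation (illustrative only; in the Mojo code these indices live in the `KVPipeline` state and `v_pipeline_release_index`). It reproduces the `wKn(state)` traces shown above:

```python
def kv_schedule(num_iters: int) -> list[str]:
    """Model of the KVConsumerPipeline index rules:
    - wK backs up and increments the index prior to waiting, except K0
    - rK increments the index after releasing
    - rV uses the backed-up index
    """
    trace = []
    idx = 0     # current pipeline index
    backup = 0  # index saved for the deferred V release
    for n in range(num_iters):
        if n == 0:
            trace.append(f"wK0({idx})")        # wait on K0 (no pre-increment)
        else:
            backup = idx                        # back up for rV{n-1}
            idx += 1                            # pre-increment before waiting
            trace.append(f"wK{n}({idx})")       # wait on Kn
            trace.append(f"rV{n - 1}({backup})")  # release V{n-1} at backup
        trace.append(f"rK{n}({idx})")           # release Kn
        idx += 1                                # rK increments after releasing
        trace.append(f"wV{n}({idx})")           # wait on Vn
    trace.append(f"rV{num_iters - 1}({idx})")   # final V release
    return trace
```

For `num_iters=3` this yields exactly the trace ending in `rV2(5)` from the docstring.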
--- ## KVPipeline
`@register_passable(trivial)` `struct KVPipeline[num_kv_stages: Int, num_mma_stages: Int]` KVPipeline has `num_kv_stages * num_mma_stages` stages. `num_kv_stages` refers to how many `K` and `V` tiles we pipeline for performing the `S = Q@K'` and `O += P@V` MMAs. Each of these MMAs is broken up into `num_mma_stages` pipelined MMAs. We set `step=False` for all but the last MMA that completes the operation. An alternative implementation would separate the two, and potentially allow for more overall stages at the cost of slightly more bookkeeping. ## Fields * ​mbar (`MBarType`): * ​state (`PipelineState[num_kv_stages]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_stages` `comptime num_stages = (num_kv_stages * num_mma_stages)` ## Methods ### `__init__` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `init` `init(self)` ### `producer_mbar` `producer_mbar[mma_stage: Int](self) -> MBarType` **Returns:** `MBarType` ### `consumer_mbar` `consumer_mbar[mma_stage: Int](self, idx: UInt32) -> MBarType` **Returns:** `MBarType` `consumer_mbar[mma_stage: Int](self) -> MBarType` **Returns:** `MBarType` ### `producer_acquire` `producer_acquire[mma_stage: Int = (num_mma_stages - 1)](self)` Returns the dynamic pipe idx. 
### `consumer_wait` `consumer_wait[mma_stage: Int = (num_mma_stages - 1)](self)` ### `consumer_release` `consumer_release[mma_stage: Int = (num_mma_stages - 1)](mut self, e: Int32)` ### `num_mbars` `static num_mbars() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
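As a sketch of the bookkeeping described above, here is a Python model under the assumption of a kv-major flat slot layout (the actual mbar ordering in the Mojo implementation may differ): the pipeline owns `num_kv_stages * num_mma_stages` slots, and the kv-level ring advances only when the last MMA sub-stage releases.

```python
class KVPipelineModel:
    """Python model of KVPipeline[num_kv_stages, num_mma_stages] bookkeeping.
    Assumes a kv-major flat layout of the num_kv_stages * num_mma_stages
    slots; the real slot ordering may differ."""

    def __init__(self, num_kv_stages: int, num_mma_stages: int) -> None:
        self.num_kv_stages = num_kv_stages
        self.num_mma_stages = num_mma_stages
        self.num_stages = num_kv_stages * num_mma_stages  # total slots
        self.kv_index = 0  # current position in the kv-level ring

    def slot(self, mma_stage: int) -> int:
        # flat slot for the given MMA sub-stage of the current K/V tile
        return self.kv_index * self.num_mma_stages + mma_stage

    def consumer_release(self, mma_stage: int, step: bool) -> None:
        # step=False for all but the last MMA of the operation: only the
        # final sub-stage advances the kv-level ring
        if step and mma_stage == self.num_mma_stages - 1:
            self.kv_index = (self.kv_index + 1) % self.num_kv_stages
```

This mirrors the docstring's point that an alternative design could decouple the kv and mma stage counts for more total stages at the cost of extra bookkeeping.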
--- ## KVProducerPipeline
`@register_passable(trivial)` `struct KVProducerPipeline[dtype: DType, config: FA4Config]` ## Fields * ​kv\_pipeline (`KVPipeline[config.num_kv_stages, config.num_mma_stages]`): * ​smem (`KVProducerPipeline[dtype, config].SMemType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `KPairType` `comptime KPairType = TMADestination[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]()]` ### `KType` `comptime KType = LayoutTensor[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode](), MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]` ### `kv_bytes` `comptime kv_bytes = (KVProducerPipeline[dtype, config].kv_elements * size_of[dtype]())` ### `kv_elements` `comptime kv_elements = tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]().size()` ### `SMemType` `comptime SMemType = LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]` ### `VPairType` `comptime VPairType = TMADestination[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode]()]` ### `VType` `comptime VType = LayoutTensor[dtype, tile_layout_mn_major[dtype, config.padded_depth, config.BK1, config.swizzle_mode](), MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]` ## Methods ### `__init__` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, 
address_space=AddressSpace.SHARED], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self` `__init__(kv_pipeline: KVPipeline[config.num_kv_stages, config.num_mma_stages], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self` ### `init` `init(self)` Only one of the producer or consumer should call `init()`. ### `get_kv_smem` `get_kv_smem[*, mma_stage: Int](self) -> KVProducerPipeline[dtype, config].SMemType` **Returns:** `KVProducerPipeline` ### `get_k` `get_k[*, mma_stage: Int, expect: Bool = True](self) -> KVProducerPipeline[dtype, config].KPairType` **Returns:** `KVProducerPipeline` ### `get_v` `get_v[*, mma_stage: Int](self) -> KVProducerPipeline[dtype, config].VPairType` **Returns:** `KVProducerPipeline` ### `acquire_kv` `acquire_kv[*, mma_stage: Int = (config - 1)](self)` ### `commit_kv_step` `commit_kv_step(mut self)` Step the kv pipeline. This does not perform the commit on the mbars; that should be handled by the `tma_op.async_copy`.
--- ## MBarPipeline
`@register_passable(trivial)` `struct MBarPipeline[number_of_stages: Int]` ## Fields * ​mbar (`MBarType`): * ​state (`PipelineState[MBarPipeline[number_of_stages].num_stages]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_stages` `comptime num_stages = number_of_stages` ## Methods ### `__init__` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `init` `init[*, num_producer: UInt32 = 1, num_consumer: UInt32 = 1](self)` ### `num_mbars` `static num_mbars() -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## ProducerPipeline
`@register_passable(trivial)` `struct ProducerPipeline[number_of_stages: Int]` ## Fields * ​mbar (`MBarType`): * ​state (`PipelineState[ProducerPipeline[number_of_stages].num_stages]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `num_stages` `comptime num_stages = number_of_stages` ## Methods ### `__init__` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]) -> Self` ### `producer_mbar` `producer_mbar(self) -> MBarType` **Returns:** `MBarType` ### `consumer_mbar` `consumer_mbar(self) -> MBarType` **Returns:** `MBarType` ### `acquire` `acquire(self)` ### `commit` `commit(mut self)` ### `commit_mma` `commit_mma(self)` `commit_mma(self, elect: Int32)` ### `step` `step(mut self)`
--- ## SM100MHA2Q
`@register_passable(trivial)` `struct SM100MHA2Q[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, SchedulerType: MHATileScheduler, config: FA4Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, descriptor_shape: IndexList[3], remaining_global_dim_rank: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type]()` ### `BM` `comptime BM = config.BM` ### `BN` `comptime BN = config.BN` ### `cta_group` `comptime cta_group = 1` ### `depth` `comptime depth = config.depth` ### `group` `comptime group = config.group` ### `k_bytes` `comptime k_bytes = SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].swizzle_granularity * config)).__rmul__[DType.uint32, 1](SIMD[DType.uint32, 1](SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, 
use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_dt_size))` ### `k_elements` `comptime k_elements = SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].swizzle_granularity * config))` ### `KVPipelineType` `comptime KVPipelineType = KVPipeline[config.num_kv_stages, config.num_mma_stages]` ### `MMA_K` `comptime MMA_K = 16` ### `MMA_M` `comptime MMA_M = (config // 2)` ### `num_m_mmas` `comptime num_m_mmas = 2` ### `num_mma_stages` `comptime num_mma_stages = config.num_mma_stages` ### `num_q_heads` `comptime num_q_heads = config.num_q_heads` ### `OPipelineType` `comptime OPipelineType = MBarPipeline[2]` ### `padded_depth` `comptime padded_depth = config.padded_depth` ### `page_size` `comptime page_size = KVLUTType.page_size` ### `PositionType` `comptime PositionType = MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]` ### `qkv_dt_size` `comptime qkv_dt_size = size_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type]()` ### `qkv_type` `comptime qkv_type = KVLUTType.dtype` ### `qo_bytes` `comptime qo_bytes = SIMD[DType.uint32, 1]((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_dt_size * SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, 
config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qo_elements))` ### `qo_elements` `comptime qo_elements = (SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].padded_depth * SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].MMA_M)` ### `ragged` `comptime ragged = ValidLengthType.is_null.__bool__().__invert__()` ### `simd_size` `comptime simd_size = simd_width_of[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type]()` ### `swizzle_granularity` `comptime swizzle_granularity = (config.swizzle_mode.bytes() // SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_dt_size)` ### `UMMA0Type` `comptime UMMA0Type = SM100TensorAccumulatorSS[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, 
descriptor_shape, remaining_global_dim_rank].accum_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].MMA_M, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].BN, align_up(SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].depth, 16), swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].num_mma_stages]` ### `UMMA1Type` `comptime UMMA1Type = SM100TensorAccumulatorTS[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type, SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].MMA_M, config.padded_depth, 
SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].BN, config.swizzle_mode, transpose_b=False, num_stages=SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].num_mma_stages]` ### `v_bytes_per_mma` `comptime v_bytes_per_mma = SIMD[DType.uint32, 1](((SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_dt_size * 16) * config))` ## Methods ### `kernel` `static kernel(q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False](), config, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False](), config)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.depth, Tuple[]()), config)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.depth, Tuple[]()), config)], o_ptr_arg: LegacyUnsafePointer[Scalar[output_type]], ragged_tma_store: RaggedTensorMap[output_type, 
descriptor_shape, remaining_global_dim_rank, config.swizzle_mode], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, ScoreModType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])` ### `mask_status` `static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `scale_write_output` `static scale_write_output(local_row: UInt32, inv_row_sum: Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type], o_smem: LegacyUnsafePointer[Scalar[output_type], address_space=AddressSpace.SHARED], o_tmem: TMemTile[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type, (SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].BM // 2), SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].padded_depth], o_ptr: LegacyUnsafePointer[Scalar[output_type]], ragged_tma_store: RaggedTensorMap[output_type, descriptor_shape, remaining_global_dim_rank, config.swizzle_mode], warp_idx: UInt32, consumer_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], current_seq: Int, num_output_rows: Int32)` ### `softmax` `static softmax(tmem_addr: UInt32, 
warp_idx: UInt32, mbars: FA4MiscMBars, o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], score_row: UInt32, seq_info: SeqInfo, mask: MaskType, num_keys: UInt32, scale: Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type], score_mod: ScoreModType, max_seq_len: UInt32, o_ptr_arg: LegacyUnsafePointer[Scalar[output_type]], ragged_tma_store: RaggedTensorMap[output_type, descriptor_shape, remaining_global_dim_rank, config.swizzle_mode], o_smem: LegacyUnsafePointer[Scalar[output_type], address_space=AddressSpace.SHARED], sink_weights: SinkType)` ### `correction` `static correction(tmem_addr: UInt32, mbars: FA4MiscMBars, o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], score_row: UInt32, num_keys: UInt32, mask: MaskType)` ### `load` `static load(mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False](), config, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.depth, decoding=False](), config)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.depth, Tuple[]()), config)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 
1, config.depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, config.depth, Tuple[]()), config)], kv_lut: KVLUTType, q_smem: LegacyUnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace.SHARED])` ### `descriptor_q` `static descriptor_q(q_smem: LegacyUnsafePointer[Scalar[SM100MHA2Q[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type], address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair` **Returns:** [`MMASmemDescriptorPair`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptorPair) ### `mma` `static mma(tmem_addr: UInt32, mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], score_row: UInt32, num_keys: UInt32, mask: MaskType, q_smem: LegacyUnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace.SHARED])`
--- ## SM100TensorAccumulatorSS (Mha_sm100_2q)
`@register_passable(trivial)` `struct SM100TensorAccumulatorSS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BK: Int, *, swizzle_a: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, transpose_b: Bool = True, cta_group: Int = 1, num_stages: Int = 1]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `a_layout` `comptime a_layout = tile_layout_k_major[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, align_up(MMA_M, 8), SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_a]()` ### `accum_t` `comptime accum_t = accum_type` ### `AType` `comptime AType = MMASmemDescriptorPair` ### `b_layout` `comptime b_layout = tile_layout_k_major[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_b]() if transpose_b else tile_layout_mn_major[SM100TensorAccumulatorSS[operand_type, accum_type, 
MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, MMA_N, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK, swizzle_b]()` ### `BType` `comptime BType = MMASmemDescriptorPair` ### `CType` `comptime CType = TMemTile[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].accum_t, MMA_M, MMA_N]` ### `idesc` `comptime idesc = UMMAInsDescriptor.create[UMMAKind.KIND_F16, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].accum_t, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t, Index[dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()` ### `MMA_K` `comptime MMA_K = 16` ### `num_k_blocks` `comptime num_k_blocks = (SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].padded_BK // 16)` ### `num_k_blocks_per_stage` `comptime num_k_blocks_per_stage = (SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].num_k_blocks // num_stages)` ### `num_k_mmas` `comptime num_k_mmas = ceildiv(BK, 16)` ### `operand_size` `comptime 
operand_size = size_of[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t]()` ### `operand_t` `comptime operand_t = operand_type` ### `padded_BK` `comptime padded_BK = align_up(BK, SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].swizzle_granularity)` ### `swizzle_granularity` `comptime swizzle_granularity = (max(swizzle_a.bytes(), swizzle_b.bytes()) // size_of[SM100TensorAccumulatorSS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_a=swizzle_a, swizzle_b=swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages].operand_t]())` ## Methods ### `mma` `static mma[*, stage_idx: Int = 0](a: MMASmemDescriptorPair, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)`
--- ## SM100TensorAccumulatorTS (Mha_sm100_2q)
`@register_passable(trivial)` `struct SM100TensorAccumulatorTS[operand_type: DType, accum_type: DType, MMA_M: Int, MMA_N: Int, BK: Int, swizzle_b: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B, *, transpose_b: Bool = True, cta_group: Int = 1, num_stages: Int = 1, padded_BK: Int = BK]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_t` `comptime accum_t = accum_type` ### `AType` `comptime AType = TMemTile[operand_type, MMA_M, BK]` ### `b_layout` `comptime b_layout = tile_layout_k_major[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].operand_t, MMA_N, BK, swizzle_b]() if transpose_b else tile_layout_mn_major[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].operand_t, MMA_N, BK, swizzle_b]()` ### `BType` `comptime BType = MMASmemDescriptorPair` ### `CType` `comptime CType = TMemTile[SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].accum_t, MMA_M, MMA_N]` ### `idesc` `comptime idesc = UMMAInsDescriptor.create[UMMAKind.KIND_F16, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].accum_t, 
SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].operand_t, SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].operand_t, Index[dtype=DType.uint32](MMA_M, MMA_N), transpose_b=transpose_b]()` ### `MMA_K` `comptime MMA_K = 16` ### `num_k_blocks` `comptime num_k_blocks = (padded_BK // 16)` ### `num_k_blocks_per_stage` `comptime num_k_blocks_per_stage = (SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].num_k_blocks // num_stages)` ### `num_k_mmas` `comptime num_k_mmas = (BK // 16)` ### `operand_size` `comptime operand_size = size_of[operand_type]()` ### `operand_t` `comptime operand_t = operand_type` ### `swizzle_granularity` `comptime swizzle_granularity = (swizzle_b.bytes() // SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].operand_size)` ## Methods ### `descriptor_a` `static descriptor_a(a_tmem: UInt32) -> SM100TensorAccumulatorTS[operand_type, accum_type, MMA_M, MMA_N, BK, swizzle_b, transpose_b=transpose_b, cta_group=cta_group, num_stages=num_stages, padded_BK=padded_BK].AType` **Returns:** `SM100TensorAccumulatorTS` ### `mma` `static mma[*, stage_idx: Int = 0](a: UInt32, b: MMASmemDescriptorPair, c: UInt32, *, c_scale: UInt32, elect: Int32)`
--- ## STMatrixLayout
`@register_passable(trivial)` `struct STMatrixLayout[BM: Int, BN: Int, *, num_threads: Int, accum_type_size: Int]` Layout for using `st_matrix` for writing the final accumulator to smem. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `bits` `comptime bits = (64 * accum_type_size)` ### `bits_per_byte` `comptime bits_per_byte = 8` ### `element_layout` `comptime element_layout = Layout.row_major(1, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].frag_simdwidth)` ### `elements_per_repeat` `comptime elements_per_repeat = 4` ### `frag_simdwidth` `comptime frag_simdwidth = 2` ### `frag_size` `comptime frag_size = ((BN * 2) // 4)` ### `num_m_tiles` `comptime num_m_tiles = (STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].num_m_tiles_total // STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].num_warpgroups)` ### `num_m_tiles_total` `comptime num_m_tiles_total = ceildiv((2 * BM), 128)` ### `num_row_blocks_per_mma` `comptime num_row_blocks_per_mma = 2` ### `num_warpgroups` `comptime num_warpgroups = ceildiv(num_threads, 128)` ### `repeat` `comptime repeat = (BN // 8)` ### `row_of_frags_layout` `comptime row_of_frags_layout = Layout.row_major(STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].num_m_tiles, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].frag_size)` ### `TensorType` `comptime TensorType[dtype: DType] = 
LayoutTensor[dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].element_layout]` #### Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): ### `thread_cols` `comptime thread_cols = 4` ### `vec_local_layout` `comptime vec_local_layout = Layout(IntTuple(IntTuple(2, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].num_m_tiles), IntTuple(STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].repeat), Tuple[]()), IntTuple(IntTuple(STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].frag_simdwidth, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size].frag_size), IntTuple(4), Tuple[]()))` ## Methods ### `__init__` `__init__() -> Self`
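The `comptime` members above are plain integer arithmetic in `BM`, `BN`, `num_threads`, and `accum_type_size`. The following is a hypothetical Python transcription of those formulas (an illustration only, not the Mojo implementation):

```python
def cdiv(a, b):
    # Ceiling division, mirroring the doc's ceildiv.
    return -(-a // b)

def st_matrix_layout(BM, BN, num_threads, accum_type_size):
    # Direct transcription of the comptime formulas listed above.
    frag_size = (BN * 2) // 4                     # fragments per row
    repeat = BN // 8                              # st_matrix repeats along N
    num_warpgroups = cdiv(num_threads, 128)
    num_m_tiles = cdiv(2 * BM, 128) // num_warpgroups
    bits = 64 * accum_type_size
    return {"frag_size": frag_size, "repeat": repeat,
            "num_m_tiles": num_m_tiles, "bits": bits}

# Example: a 128x128 tile written by 256 threads with 4-byte (fp32) accumulators.
assert st_matrix_layout(128, 128, 256, 4) == {
    "frag_size": 64, "repeat": 16, "num_m_tiles": 1, "bits": 256}
```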
--- ## STMatrixOffsets
`@register_passable(trivial)` `struct STMatrixOffsets[BM: Int, BN: Int, *, num_threads: Int, accum_type_size: Int, curr_repeat: Int, cumulative_repeat: Int, m_mma: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `b32_per_repeat` `comptime b32_per_repeat = ((STMatrixOffsets[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size, curr_repeat=curr_repeat, cumulative_repeat=cumulative_repeat, m_mma=m_mma].STLayout.elements_per_repeat * accum_type_size) // 4)` ### `local_frag_size_b32` `comptime local_frag_size_b32 = (curr_repeat * STMatrixOffsets[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size, curr_repeat=curr_repeat, cumulative_repeat=cumulative_repeat, m_mma=m_mma].b32_per_repeat)` ### `ptr_offset` `comptime ptr_offset = (STMatrixOffsets[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size, curr_repeat=curr_repeat, cumulative_repeat=cumulative_repeat, m_mma=m_mma].b32_per_repeat * ((STMatrixOffsets[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size, curr_repeat=curr_repeat, cumulative_repeat=cumulative_repeat, m_mma=m_mma].STLayout.repeat * m_mma) + cumulative_repeat))` ### `STLayout` `comptime STLayout = STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size]` ### `tmem_col_offset` `comptime tmem_col_offset = ((cumulative_repeat * STMatrixOffsets[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size, curr_repeat=curr_repeat, cumulative_repeat=cumulative_repeat, 
m_mma=m_mma].STLayout.frag_simdwidth) * 4)` ### `tmem_offset` `comptime tmem_offset = ((STMatrixOffsets[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size, curr_repeat=curr_repeat, cumulative_repeat=cumulative_repeat, m_mma=m_mma].tmem_row_offset << 16) + STMatrixOffsets[BM, BN, num_threads=num_threads, accum_type_size=accum_type_size, curr_repeat=curr_repeat, cumulative_repeat=cumulative_repeat, m_mma=m_mma].tmem_col_offset)` ### `tmem_row_offset` `comptime tmem_row_offset = (16 * m_mma)` ## Methods ### `__init__` `__init__() -> Self`
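Per the `tmem_offset` formula above, the tensor-memory address packs the row offset into the high 16 bits and the column offset into the low 16 bits. A small Python sketch of that packing (the constant `FRAG_SIMDWIDTH = 2` comes from `STMatrixLayout.frag_simdwidth`; this is an illustration, not the Mojo code):

```python
FRAG_SIMDWIDTH = 2  # STMatrixLayout.frag_simdwidth from the section above

def tmem_offset(m_mma, cumulative_repeat):
    # tmem_row_offset = 16 * m_mma goes in the high 16 bits;
    # tmem_col_offset = cumulative_repeat * frag_simdwidth * 4 in the low 16.
    tmem_row_offset = 16 * m_mma
    tmem_col_offset = cumulative_repeat * FRAG_SIMDWIDTH * 4
    return (tmem_row_offset << 16) + tmem_col_offset

# Row and column halves can be recovered by shifting and masking.
packed = tmem_offset(m_mma=2, cumulative_repeat=3)
assert packed >> 16 == 32 and packed & 0xFFFF == 24
```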
--- ## TMADestination
`@register_passable(trivial)` `struct TMADestination[dtype: DType, layout: Layout]` ## Fields * ​mbar (`MBarType`): * ​smem (`LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem: LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]) -> Self` ### `split_smem` `split_smem[first: Layout, second: Layout](self) -> Tuple[LayoutTensor[dtype, first, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128], LayoutTensor[dtype, second, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)
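Conceptually, `split_smem` carves the destination's shared-memory tensor into two tensors with the given layouts. A hypothetical Python analogue over a flat buffer (this assumes a contiguous split, which the signature does not guarantee):

```python
def split(buf, first_size, second_size):
    # Carve two contiguous views out of one flat backing buffer,
    # analogous to splitting one smem allocation into two layouts.
    assert first_size + second_size <= len(buf)
    return buf[:first_size], buf[first_size:first_size + second_size]

a, b = split(list(range(8)), 3, 4)
assert a == [0, 1, 2] and b == [3, 4, 5, 6]
```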
--- ## TMemTile
`@register_passable(trivial)` `struct TMemTile[dtype_: DType, BM: Int, BN: Int]` ## Fields * ​tmem\_addr (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `dtype` `comptime dtype = dtype_` ### `dtype_size` `comptime dtype_size = size_of[TMemTile[dtype_, BM, BN].dtype]()` ### `num_m_tiles` `comptime num_m_tiles = (BM // 64)` ## Methods ### `__init__` `__init__(tmem_addr: UInt32) -> Self` ### `__getitem__` `__getitem__(self, i: UInt32) -> Self` ### `offset` `offset[m_mma: Int, n_mma: Int](self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `allocate_register_tile` `static allocate_register_tile[*, num_threads: Int]() -> LayoutTensor[TMemTile[dtype_, BM, BN].dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].element_layout]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `store_async` `store_async[*, num_threads: Int](self, src: LayoutTensor[TMemTile[dtype_, BM, BN].dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].element_layout])` 
`store_async[src_type: DType](self, src: LayoutTensor[src_type, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL])` ### `store` `store[*, num_threads: Int](self, src: LayoutTensor[TMemTile[dtype_, BM, BN].dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].element_layout])` `store[src_type: DType](self, src: LayoutTensor[src_type, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL])` ### `load_async_with_st_matrix_layout` `load_async_with_st_matrix_layout[*, num_threads: Int](self) -> LayoutTensor[TMemTile[dtype_, BM, BN].dtype, STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].vec_local_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=STMatrixLayout[BM, BN, num_threads=num_threads, accum_type_size=TMemTile[dtype_, BM, BN].dtype_size].element_layout]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor) ### `load_async` `load_async(self) -> LayoutTensor[TMemTile[dtype_, BM, BN].dtype, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## apply_mask
`apply_mask[dtype: DType, BN: Int, MaskType: MHAMask, ScoreModType: ScoreModTrait, //, *, use_score_mod: Bool, masked: Bool, last_iter: Bool, decoding: Bool = False](srow: LayoutTensor[dtype, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL], mask: MaskType, score_mod: ScoreModType, scale_log2e: Scalar[dtype], *, prompt_idx: UInt32, q_head_idx: UInt32, kv_tile_start_row: UInt32, max_seq_len: UInt32, num_keys: UInt32, score_row: UInt32)`
--- ## break_into_powers_of_two
`break_into_powers_of_two[origins: OriginSet, //, func: fn[pow_two: Int, offset: Int]() capturing -> None, N: Int, *, max_value: Int = 128]()`
--- ## build_mma_ss
`build_mma_ss(kind: String, layout_a: Layout, layout_b: Layout, *, operand_size: Int, num_k_mmas: Int) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
--- ## build_mma_ts
`build_mma_ts(kind: String, layout_b: Layout, *, operand_size: Int, num_k_mmas: Int) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String)
--- ## bulk_mma
`bulk_mma[kind: UMMAKind, //, layout_a: Layout, layout_b: Layout, *, num_k_mmas: Int, operand_size: Int](idesc: UMMAInsDescriptor[kind], a: MMASmemDescriptorPair, b: MMASmemDescriptorPair, c_tmem: UInt32, c_scale: UInt32, elect: Int32)` `bulk_mma[kind: UMMAKind, //, layout_b: Layout, *, num_k_mmas: Int, operand_size: Int](idesc: UMMAInsDescriptor[kind], a: UInt32, b: MMASmemDescriptorPair, c_tmem: UInt32, c_scale: UInt32, elect: Int32)`
--- ## cumulative_power_of_two
`cumulative_power_of_two(N: Int, i: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## elect
`elect() -> Int32` **Returns:** [`Int32`](/mojo/stdlib/builtin/simd/#int32)
--- ## elect_mma_arrive
`elect_mma_arrive[cta_group: Int = 1](mbar_ptr: LegacyUnsafePointer[type, address_space=AddressSpace.SHARED, mut=mut, origin=origin], elect: Int32)` Arrive at the mbar pointer for the MMA instruction. **Parameters:** * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of CTAs used by the MMA. **Args:** * ​mbar\_ptr (`LegacyUnsafePointer`): Pointer to the mbar. * ​elect ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The result of `elect()`, selecting the lane that performs the arrive.
--- ## extract_power_of_two
`extract_power_of_two(N: Int, i: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
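Judging from the signatures above, `break_into_powers_of_two` appears to invoke `func[pow_two, offset]` once per power-of-two chunk covering `N` (capped at `max_value`), with `extract_power_of_two` and `cumulative_power_of_two` giving the i-th chunk size and its starting offset. A hypothetical Python sketch of that decomposition (an assumption about the semantics, not the Mojo implementation):

```python
def powers_of_two(N, max_value=128):
    # Greedy decomposition of N into descending powers of two,
    # each chunk capped at max_value.
    chunks, remaining = [], N
    while remaining > 0:
        p = max_value
        while p > remaining:
            p //= 2
        chunks.append(p)
        remaining -= p
    return chunks

def cumulative(chunks, i):
    # Offset of chunk i: sum of the chunk sizes before it.
    return sum(chunks[:i])

chunks = powers_of_two(300)
assert chunks == [128, 128, 32, 8, 4]
assert cumulative(chunks, 2) == 256
```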
--- ## mha_sm100_2q
## `comptime` values ### `LocalTensor` `comptime LocalTensor[dtype: DType, layout: Layout, element_layout: Layout = Layout(IntTuple(1), IntTuple(1))] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.LOCAL, element_layout=element_layout]` #### Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): * ​element\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ### `logger` `comptime logger = Logger[DEFAULT_LEVEL](stdout, "", False)` ### `MBarType` `comptime MBarType = LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED]` ### `SharedMemPointer` `comptime SharedMemPointer[type: AnyType] = LegacyUnsafePointer[type, address_space=AddressSpace.SHARED]` #### Parameters * ​type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): ### `SharedMemTensor` `comptime SharedMemTensor[dtype: DType, layout: Layout] = LayoutTensor[dtype, layout, MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]` #### Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): * ​layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): ## Structs * [​`ConsumerPipeline`](./ConsumerPipeline): * [​`FA4Config`](./FA4Config): * [​`FA4MiscMBars`](./FA4MiscMBars): * [​`KVConsumerPipeline`](./KVConsumerPipeline): Pipeline for managing the consumption of K and V. This follows the order of the Tri Dao and Cutlass implementations (modulo any rotation of the ops through the iterations). * [​`KVPipeline`](./KVPipeline): KVPipeline has `num_kv_stages * num_mma_stages` stages. `num_kv_stages` refers to how many `K` and `V` tiles we pipeline for performing the `S = Q@K'` and `O += P@V` MMAs. Each of these MMAs is broken up into `num_mma_stages` pipelined MMAs. We set `step=False` for all but the last MMA that completes the operation.
An alternative implementation would separate the two, and potentially allow for more overall stages at the cost of slightly more bookkeeping. * [​`KVProducerPipeline`](./KVProducerPipeline): * [​`MBarPipeline`](./MBarPipeline): * [​`ProducerPipeline`](./ProducerPipeline): * [​`SM100MHA2Q`](./SM100MHA2Q): * [​`SM100TensorAccumulatorSS`](./SM100TensorAccumulatorSS): * [​`SM100TensorAccumulatorTS`](./SM100TensorAccumulatorTS): * [​`STMatrixLayout`](./STMatrixLayout): Layout for using `st_matrix` for writing the final accumulator to smem. * [​`STMatrixOffsets`](./STMatrixOffsets): * [​`TMADestination`](./TMADestination): * [​`TMemTile`](./TMemTile): ## Functions * [​`apply_mask`](./apply_mask): * [​`break_into_powers_of_two`](./break_into_powers_of_two): * [​`build_mma_ss`](./build_mma_ss): * [​`build_mma_ts`](./build_mma_ts): * [​`bulk_mma`](./bulk_mma): * [​`cumulative_power_of_two`](./cumulative_power_of_two): * [​`elect`](./elect): * [​`elect_mma_arrive`](./elect_mma_arrive): Arrive at the mbar pointer for the MMA instruction. * [​`extract_power_of_two`](./extract_power_of_two): * [​`maximum`](./maximum): * [​`mha_sm100_dispatch`](./mha_sm100_dispatch): * [​`sum`](./sum):
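The `KVPipeline` staging described in the module overview above can be sketched in Python: the pipeline has `num_kv_stages * num_mma_stages` slots, and only the final MMA of each K/V tile sets `step=True`. This is an illustration of the accounting, not the Mojo implementation:

```python
def kv_pipeline_steps(num_kv_stages, num_mma_stages):
    # Enumerate (kv_stage, mma_stage, step) triples for one pass over the
    # pipeline; step is True only for the MMA that completes a K/V tile.
    out = []
    for kv in range(num_kv_stages):
        for mma in range(num_mma_stages):
            out.append((kv, mma, mma == num_mma_stages - 1))
    return out

stages = kv_pipeline_steps(2, 2)
assert len(stages) == 4  # num_kv_stages * num_mma_stages total stages
assert [step for (_, _, step) in stages] == [False, True, False, True]
```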
--- ## maximum
`maximum[dtype: DType, BN: Int, //, *, width: Int = 8](x: LayoutTensor[dtype, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) `maximum[dtype: DType, BN: Int, width: Int, //](x: LayoutTensor[dtype, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL], init: SIMD[dtype, width]) -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## mha_sm100_dispatch (Mha_sm100_2q)
`mha_sm100_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, ScoreModType: ScoreModTrait, output_type: DType, MaxPromptLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, //, config: MHAConfig[dtype], group: Int, use_score_mod: Bool, ragged: Bool, sink: Bool, _is_cache_length_accurate: Bool](output: DeviceBuffer[output_type], q_arg: LegacyUnsafePointer[Scalar[q_type]], k: KVType, v: KVType, num_rows_q: Int, mask: MaskType, score_mod: ScoreModType, valid_length: LegacyUnsafePointer[UInt32], max_prompt_len_arg: MaxPromptLenType, max_cache_valid_length_arg: Int, scale: Float32, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]], batch_size_arg: Int, partition: PartitionType, ctx: DeviceContext, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]])`
--- ## sum (Mha_sm100_2q)
`sum[dtype: DType, BN: Int, //, *, width: Int = 8](x: LayoutTensor[dtype, Layout.row_major(BN), MutAnyOrigin, address_space=AddressSpace.LOCAL]) -> SIMD[dtype, 2]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## mha_sm90
## Functions * [​`mha_sm90_dispatch`](./mha_sm90_dispatch):
--- ## mha_sm90_dispatch
`mha_sm90_dispatch[q_type: DType, KVType: MHAOperand, MaskType: MHAMask, ScoreModType: ScoreModTrait, output_type: DType, MaxPromptLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, //, config: MHAConfig[dtype], group: Int, use_score_mod: Bool, ragged: Bool, sink: Bool, _is_cache_length_accurate: Bool](output: DeviceBuffer[output_type], q_arg: DeviceBuffer[q_type], k: KVType, v: KVType, num_rows_q: Int, mask_functor: MaskType, score_mod: ScoreModType, valid_length: DeviceBuffer[DType.uint32], max_prompt_len_arg: MaxPromptLenType, max_cache_valid_length_arg: Int, scale: Float32, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]], batch_size_arg: Int, partition: PartitionType, ctx: DeviceContext, sink_weights: OptionalReg[LayoutTensor[q_type, Layout.row_major(-1), MutAnyOrigin]])`
--- ## MHASchedule
`@register_passable(trivial)` `struct MHASchedule` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DEFAULT` `comptime DEFAULT = MHASchedule(0)` ### `PROMPT_ROTATE` `comptime PROMPT_ROTATE = MHASchedule(1)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## MHASchedulerSynchronization
`@register_passable(trivial)` `struct MHASchedulerSynchronization` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ALL` `comptime ALL = MHASchedulerSynchronization(2)` ### `DEFAULT` `comptime DEFAULT = MHASchedulerSynchronization.PRODUCER` ### `NONE` `comptime NONE = MHASchedulerSynchronization(0)` ### `PRODUCER` `comptime PRODUCER = MHASchedulerSynchronization(1)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## MHATileScheduler
The MHATileScheduler trait describes a schedule for the persistent kernel. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `device_type` `comptime device_type` Indicates the type being used on accelerator devices. ### `may_advance` `comptime may_advance` ### `mha_schedule` `comptime mha_schedule`
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `get_current_work_info` `get_current_work_info[ValidLengthType: OptionalPointer, //](self: _Self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo` Returns the current `WorkInfo`. **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) ### `advance` `advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self: _Self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]` Advance state to the next work item. Returns a `SeqInfo` for the next work item if there is more work, or an empty `OptionalReg` if there is none. **Returns:** [`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg) ### `grid_dim` `static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]` Return the grid\_dim required for the kernel. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `initial_state` `initial_state[ValidLengthType: OptionalPointer, //](self: _Self, ptr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState` Create the initial state object. 
**Returns:** [`MHATileState`](/mojo/kernels/nn/mha_tile_scheduler/MHATileState) ### `unsafe_seq_info` `unsafe_seq_info[ValidLengthType: OptionalPointer, //](self: _Self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo` **Returns:** [`SeqInfo`](/mojo/kernels/nn/mha_tile_scheduler/SeqInfo) ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
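Taken together, the required methods support a persistent-kernel scheduling loop. The following is a schematic Mojo sketch, not an exact kernel: `scheduler`, `tile_summary`, `smem_ptr`, and `pipeline_idx` stand in for values set up by the surrounding (hypothetical) kernel code.

```mojo
# Schematic per-block scheduling loop over an MHATileScheduler
# implementation. All surrounding variables are hypothetical.
var state = scheduler.initial_state(smem_ptr, tile_summary)
var work = scheduler.get_current_work_info(tile_summary, state)
while work.is_valid():
    # ... compute attention for the tile identified by
    # (work.prompt_offset, work.head_idx, work.prompt_idx) ...
    var next = scheduler.advance[producer=False](
        tile_summary, state, pipeline_idx
    )
    if not next:
        break
    work = scheduler.get_current_work_info(tile_summary, state)
```

For a scheduler whose `may_advance` is `False` (such as `TransientScheduler`), each launched block processes a single work item and the loop body runs at most once.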
--- ## MHATileState
`@register_passable(trivial)` `struct MHATileState` ## Fields * ​idx (`UInt32`): * ​sidx\_ptr (`LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED]`): * ​max\_idx (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(idx: UInt32, sidx_ptr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED], max_idx: UInt32) -> Self` ### `is_valid` `is_valid(self, idx: UInt32) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## MHATileSummary
`@register_passable(trivial)` `struct MHATileSummary[ValidLengthType: OptionalPointer]` ## Fields * ​batch\_size (`UInt32`): * ​max\_num\_prompt\_tiles (`UInt32`): * ​valid\_length (`ValidLengthType`): * ​max\_seq\_len (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True if ValidLengthType.__copyinit__is_trivial else ValidLengthType.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = True if ValidLengthType.__del__is_trivial else ValidLengthType.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True if ValidLengthType.__moveinit__is_trivial else ValidLengthType.__moveinit__is_trivial` ## Methods ### `__init__` `__init__(batch_size: UInt32, max_num_prompt_tiles: UInt32, valid_length: ValidLengthType, max_seq_len: UInt32) -> Self` ### `get_current_work_info` `get_current_work_info[tile_shape: UInt32, num_heads: UInt32, schedule: MHASchedule](self, idx: UInt32) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) `get_current_work_info[tile_shape: UInt32, num_heads: UInt32, schedule: MHASchedule](self, idx: MHATileState) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) ### `unsafe_get_current_work_info` `unsafe_get_current_work_info[tile_shape: UInt32, num_heads: UInt32, schedule: MHASchedule](self, idx: UInt32) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) ### `max_idx` `max_idx(self, num_heads: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `grid_dim` `static grid_dim[num_heads: 
UInt32](max_num_prompt_tiles: UInt32, batch_size: UInt32) -> Tuple[Int, Int, Int]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `seq_info` `seq_info(self, work: WorkInfo) -> SeqInfo` **Returns:** [`SeqInfo`](/mojo/kernels/nn/mha_tile_scheduler/SeqInfo) ### `unsafe_seq_info` `unsafe_seq_info[tile_shape: UInt32, num_heads: UInt32, schedule: MHASchedule](self, idx: UInt32) -> SeqInfo` **Returns:** [`SeqInfo`](/mojo/kernels/nn/mha_tile_scheduler/SeqInfo) `unsafe_seq_info[tile_shape: UInt32, num_heads: UInt32, schedule: MHASchedule](self, state: MHATileState) -> SeqInfo` **Returns:** [`SeqInfo`](/mojo/kernels/nn/mha_tile_scheduler/SeqInfo)
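As a rough illustration of how the summary ties the batch shape to the launch grid (a hedged sketch; `valid_lengths` and the concrete `ValidLengthType` instantiation are hypothetical):

```mojo
# Describe a batch of 8 sequences, each covered by at most 16 query tiles.
var summary = MHATileSummary(
    UInt32(8),      # batch_size
    UInt32(16),     # max_num_prompt_tiles
    valid_lengths,  # hypothetical ValidLengthType with per-sequence lengths
    UInt32(2048),   # max_seq_len
)
# The static grid_dim maps (tiles, batch) to an (x, y, z) launch grid,
# parameterized on the number of heads.
var grid = summary.grid_dim[num_heads=32](UInt32(16), UInt32(8))
```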
--- ## QueuedTileScheduler
`@register_passable(trivial)` `struct QueuedTileScheduler[tile_shape: UInt32, num_heads: UInt32, /, decoding: Bool, num_ctas: UInt32 = H100.sm_count, schedule: MHASchedule = MHASchedule.DEFAULT]` If `decoding == False`, then `num_heads` is `q_num_heads`. If `decoding == True`, then `num_heads` is `kv_num_heads`. ## Fields * ​gidx\_ptr (`LegacyUnsafePointer[UInt32, address_space=AddressSpace.GLOBAL]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHATileScheduler`](/mojo/kernels/nn/mha_tile_scheduler/MHATileScheduler), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = QueuedTileScheduler[tile_shape, num_heads, decoding, num_ctas, schedule]` ### `may_advance` `comptime may_advance = True` ### `mha_schedule` `comptime mha_schedule = schedule` ## Methods ### `__init__` `__init__(gidx_ptr: LegacyUnsafePointer[UInt32]) -> Self` ### `get_current_work_info` `get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) ### `advance` `advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]` 
This function returns whether the current idx corresponds to a valid `WorkInfo`. Note that if `MHASchedulerSynchronization` is `NONE`, then we assume it is only called by `thread_idx.x==0`. **Returns:** [`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg) ### `grid_dim` `static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `initial_state` `initial_state[ValidLengthType: OptionalPointer, //](self, ptr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState` **Returns:** [`MHATileState`](/mojo/kernels/nn/mha_tile_scheduler/MHATileState) ### `unsafe_seq_info` `unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo` **Returns:** [`SeqInfo`](/mojo/kernels/nn/mha_tile_scheduler/SeqInfo) ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name.
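The single-thread requirement for `MHASchedulerSynchronization.NONE` can be sketched as follows (schematic Mojo; `scheduler`, `tile_summary`, `state`, and `pipeline_idx` are hypothetical kernel variables):

```mojo
# With sync=NONE the dequeued index is not broadcast to the block, so
# only thread 0 may advance; the result must then be shared with the
# rest of the block by other means (e.g. shared memory).
if thread_idx.x == 0:
    var next = scheduler.advance[
        producer=True, sync=MHASchedulerSynchronization.NONE
    ](tile_summary, state, pipeline_idx)
```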
--- ## SeqInfo
`@register_passable(trivial)` `struct SeqInfo` ## Fields * ​seq\_len (`UInt32`): * ​start\_of\_seq (`UInt32`): * ​prompt\_offset (`UInt32`): * ​head\_idx (`UInt32`): * ​prompt\_idx (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(seq_len: UInt32, start_of_seq: UInt32, work: WorkInfo) -> Self` ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `create` `static create[ValidLengthType: OptionalPointer, //](work: WorkInfo, valid_length: ValidLengthType, max_seq_len: UInt32) -> Self`
--- ## TileScheduler (mha_tile_scheduler)
`@register_passable(trivial)` `struct TileScheduler[tile_shape: UInt32, num_heads: UInt32, /, num_ctas: UInt32 = H100.sm_count, schedule: MHASchedule = MHASchedule.DEFAULT]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHATileScheduler`](/mojo/kernels/nn/mha_tile_scheduler/MHATileScheduler), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = TileScheduler[tile_shape, num_heads, num_ctas, schedule]` ### `may_advance` `comptime may_advance = True` ### `mha_schedule` `comptime mha_schedule = schedule` ## Methods ### `__init__` `__init__() -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_current_work_info` `get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) ### `fetch_next_work` `fetch_next_work(self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) ### `advance` `advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = 
MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]` **Returns:** [`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg) ### `grid_dim` `static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `initial_state` `initial_state[ValidLengthType: OptionalPointer, //](self, ptr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState` **Returns:** [`MHATileState`](/mojo/kernels/nn/mha_tile_scheduler/MHATileState) ### `unsafe_seq_info` `unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo` **Returns:** [`SeqInfo`](/mojo/kernels/nn/mha_tile_scheduler/SeqInfo)
--- ## TransientScheduler
`@register_passable(trivial)` `struct TransientScheduler[tile_shape: UInt32, num_heads: UInt32]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHATileScheduler`](/mojo/kernels/nn/mha_tile_scheduler/MHATileScheduler), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = TransientScheduler[tile_shape, num_heads]` ### `may_advance` `comptime may_advance = False` ### `mha_schedule` `comptime mha_schedule = MHASchedule.DEFAULT` ## Methods ### `__init__` `__init__() -> Self` ### `get_type_name` `static get_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_device_type_name` `static get_device_type_name() -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `get_current_work_info` `get_current_work_info(self) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) `get_current_work_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> WorkInfo` **Returns:** [`WorkInfo`](/mojo/kernels/nn/mha_tile_scheduler/WorkInfo) ### `advance` `advance[ValidLengthType: OptionalPointer, //, producer: Bool, sync: MHASchedulerSynchronization = MHASchedulerSynchronization.DEFAULT](self, ts: MHATileSummary[ValidLengthType], mut state: MHATileState, pipeline_idx: UInt32) -> OptionalReg[SeqInfo]` **Returns:** 
[`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg) ### `grid_dim` `static grid_dim(batch_size: UInt32, max_num_prompt_tiles: UInt32) -> Tuple[Int, Int, Int]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `initial_state` `initial_state[ValidLengthType: OptionalPointer, //](self, ptr: LegacyUnsafePointer[UInt32, address_space=AddressSpace.SHARED], tile_summary: MHATileSummary[ValidLengthType]) -> MHATileState` **Returns:** [`MHATileState`](/mojo/kernels/nn/mha_tile_scheduler/MHATileState) ### `unsafe_seq_info` `unsafe_seq_info[ValidLengthType: OptionalPointer, //](self, ts: MHATileSummary[ValidLengthType], state: MHATileState) -> SeqInfo` **Returns:** [`SeqInfo`](/mojo/kernels/nn/mha_tile_scheduler/SeqInfo)
--- ## WorkInfo (mha_tile_scheduler)
`@register_passable(trivial)` `struct WorkInfo` ## Fields * ​prompt\_offset (`UInt32`): * ​head\_idx (`UInt32`): * ​prompt\_idx (`UInt32`): * ​is\_valid\_tile (`Bool`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `is_valid` `is_valid(self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
--- ## mha_tile_scheduler
## Structs * [​`MHASchedule`](./MHASchedule): * [​`MHASchedulerSynchronization`](./MHASchedulerSynchronization): * [​`MHATileState`](./MHATileState): * [​`MHATileSummary`](./MHATileSummary): * [​`QueuedTileScheduler`](./QueuedTileScheduler): If `decoding == False`, then `num_heads` is `q_num_heads`. If `decoding == True`, then `num_heads` is `kv_num_heads`. * [​`SeqInfo`](./SeqInfo): * [​`TileScheduler`](./TileScheduler): * [​`TransientScheduler`](./TransientScheduler): * [​`WorkInfo`](./WorkInfo): ## Traits * [​`MHATileScheduler`](./MHATileScheduler):
--- ## DynamicInt
`@register_passable(trivial)` `struct DynamicInt` ## Fields * ​value (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OptionallyStaticInt`](/mojo/kernels/nn/mha_utils/OptionallyStaticInt), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `static_value` `comptime static_value = OptionalReg[Int](None)` ## Methods ### `__init__` `__init__(value: Int) -> Self` ### `__int__` `__int__(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `as_uint32` `as_uint32(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## FlashAttentionAlgorithm
`@register_passable(trivial)` `struct FlashAttentionAlgorithm` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `FLASH_ATTENTION_1` `comptime FLASH_ATTENTION_1 = FlashAttentionAlgorithm(1)` ### `FLASH_ATTENTION_2` `comptime FLASH_ATTENTION_2 = FlashAttentionAlgorithm(2)` ### `FLASH_ATTENTION_3` `comptime FLASH_ATTENTION_3 = FlashAttentionAlgorithm(3)` ### `NAIVE` `comptime NAIVE = FlashAttentionAlgorithm(0)` ## Methods ### `__init__` `__init__() -> Self` `__init__(value: Int) -> Self` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) `__eq__(self, version: Int) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `init` `init(self, dtype: DType) -> Self` ### `write_to` `write_to(self, mut writer: T)`
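The struct behaves like a small enum, with equality overloads against both `Self` and a bare version number (a minimal sketch):

```mojo
var algo = FlashAttentionAlgorithm.FLASH_ATTENTION_3
# __eq__(self, version: Int) allows comparing against a version number.
if algo == 3:
    print(String(algo))  # Stringable/Writable give a readable name
```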
--- ## MHAConfig
`@register_passable(trivial)` `struct MHAConfig[dtype: DType]` ## Fields * ​num\_heads (`UInt`): * ​depth (`UInt`): * ​padded\_depth (`UInt`): * ​num\_queries\_per\_block (`UInt`): * ​num\_keys\_per\_block (`UInt`): * ​BK (`UInt`): * ​WM (`UInt`): * ​WN (`UInt`): * ​num\_pipeline\_stages (`UInt`): * ​k\_group\_size (`UInt`): * ​algorithm (`FlashAttentionAlgorithm`): * ​swizzle\_mode (`TensorMapSwizzle`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(num_heads: UInt, depth: UInt, num_queries_per_block: OptionalReg[UInt] = None, num_keys_per_block: OptionalReg[UInt] = None, BK: OptionalReg[UInt] = None, WM: OptionalReg[UInt] = None, WN: OptionalReg[UInt] = None, num_pipeline_stages: UInt = 4, k_group_size: UInt = 1, algorithm: FlashAttentionAlgorithm = FlashAttentionAlgorithm(-1), swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_128B) -> Self` ### `block_m` `block_m(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `block_n` `block_n(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `block_k` `block_k(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `warp_m` `warp_m(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `warp_n` `warp_n(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `num_warps_m` `num_warps_m(self) -> UInt` **Returns:** 
[`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `num_warps_n` `num_warps_n(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `num_consumer_threads` `num_consumer_threads(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `num_producer_threads` `num_producer_threads[producer_consumer_kernel: Bool = False](self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `num_threads` `num_threads[producer_consumer_kernel: Bool = False](self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `swizzle_granularity` `swizzle_granularity(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `q_smem_size` `q_smem_size(self, fa3: Bool = False, persistent: Bool = False) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `kv_smem_size` `kv_smem_size(self, fa3: Bool = False) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `k_smem_size` `k_smem_size(self, fa3: Bool = False) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `v_smem_size` `v_smem_size(self, fa3: Bool = False) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `p_smem_size` `p_smem_size(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `warp_scratch_smem_size` `warp_scratch_smem_size(self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `shared_mem_bytes` `shared_mem_bytes[shared_kv: Bool = False, sm_90: Bool = False](self) -> UInt` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt) ### `__str__` `__str__(self) -> String` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String) ### `write_to` `write_to(self, mut writer: T)`
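A hedged sketch of constructing a config and querying the derived tile and launch quantities (the concrete numbers are illustrative, not recommended settings):

```mojo
# 32 heads at head depth 128; unspecified tile sizes fall back to the
# constructor defaults.
var config = MHAConfig[DType.bfloat16](
    num_heads=UInt(32),
    depth=UInt(128),
    num_queries_per_block=OptionalReg[UInt](UInt(64)),
)
# Quantities derived from the fields above:
var bm = config.block_m()                   # query rows per block
var bn = config.block_n()                   # key columns per block
var threads = config.num_threads()          # threads per CTA
var smem_bytes = config.shared_mem_bytes()  # dynamic shared memory request
```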
--- ## MHAPartitionScheme
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `accum_dtype` `comptime accum_dtype` ### `do_partition` `comptime do_partition` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. 
**Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `num_partitions` `num_partitions(self: _Self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_exp_sum_qk_max_pointer` `get_exp_sum_qk_max_pointer(self: _Self) -> LegacyUnsafePointer[Scalar[_Self.accum_dtype]]` **Returns:** `LegacyUnsafePointer` ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## NoPartition
`@register_passable(trivial)` `struct NoPartition[dtype: DType]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAPartitionScheme`](/mojo/kernels/nn/mha_utils/MHAPartitionScheme), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_dtype` `comptime accum_dtype = dtype` ### `do_partition` `comptime do_partition = False` ## Methods ### `__init__` `__init__() -> Self` ### `num_partitions` `num_partitions(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_exp_sum_qk_max_pointer` `get_exp_sum_qk_max_pointer(self) -> LegacyUnsafePointer[Scalar[NoPartition[dtype].accum_dtype]]` **Returns:** `LegacyUnsafePointer`
--- ## OptionallyStaticInt
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ### `static_value` `comptime static_value` ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. 
**Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `as_uint32` `as_uint32(self: _Self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `__int__` `__int__(self: _Self) -> Int` Get the integral representation of the value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integral representation of the value. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## SplitKPartition
`@register_passable(trivial)` `struct SplitKPartition[dtype: DType]` ## Fields * ​ptr (`LegacyUnsafePointer[Scalar[SplitKPartition[dtype].accum_dtype]]`): * ​num\_partitions\_value (`UInt32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MHAPartitionScheme`](/mojo/kernels/nn/mha_utils/MHAPartitionScheme), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_dtype` `comptime accum_dtype = dtype` ### `do_partition` `comptime do_partition = True` ## Methods ### `__init__` `__init__(ptr: LegacyUnsafePointer[Scalar[SplitKPartition[dtype].accum_dtype]], num_partitions_value: UInt32) -> Self` ### `num_partitions` `num_partitions(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) ### `get_exp_sum_qk_max_pointer` `get_exp_sum_qk_max_pointer(self) -> LegacyUnsafePointer[Scalar[SplitKPartition[dtype].accum_dtype]]` **Returns:** `LegacyUnsafePointer`
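`SplitKPartition` is the partitioned counterpart of `NoPartition`. A hedged construction sketch; `scratch_ptr` is a hypothetical scratch buffer for the per-partition softmax statistics:

```mojo
# Partition the keys 4 ways; the scratch buffer holds per-partition
# exp-sum / qk-max values for the final cross-partition reduction.
var scheme = SplitKPartition[DType.float32](scratch_ptr, UInt32(4))
var n = scheme.num_partitions()  # UInt32(4)
```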
--- ## StaticInt
`@register_passable(trivial)` `struct StaticInt[value: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`OptionallyStaticInt`](/mojo/kernels/nn/mha_utils/OptionallyStaticInt), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `static_value` `comptime static_value = OptionalReg[Int](value)` ## Methods ### `__init__` `__init__() -> Self` ### `__int__` `__int__(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `as_uint32` `as_uint32(self) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
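`StaticInt` and `DynamicInt` are the two implementations of `OptionallyStaticInt`: the former exposes its value through `static_value` at compile time, the latter carries it in a runtime `UInt32` field. A schematic sketch of dispatching on this distinction:

```mojo
# Generic code can specialize when the value is known at compile time.
fn describe[IntT: OptionallyStaticInt](n: IntT):
    @parameter
    if IntT.static_value:
        print("static:", IntT.static_value.value())
    else:
        print("dynamic:", Int(n))

fn main():
    describe(StaticInt[128]())  # static_value = OptionalReg[Int](128)
    describe(DynamicInt(64))    # static_value = OptionalReg[Int](None)
```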
--- ## dispatch_mask_and_score_mod
`dispatch_mask_and_score_mod[mask_type: String, score_mod_type: String, callback_fn: callback_fn_type, local_window_size: Int = -1, num_heads: Int = -1]()`
--- ## dispatch_materialized_mask_and_score_mod
`dispatch_materialized_mask_and_score_mod[dtype: DType, layout: Layout, //, score_mod_type: String, callback_fn: callback_fn_type, num_heads: Int = -1](mask_nd: LayoutTensor[dtype, layout, MutAnyOrigin], start_pos_nd: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## get_start_and_end_for_partitions
`get_start_and_end_for_partitions[tile_size: Int](num_keys: Int, num_partitions: Int, partition_idx: Int) -> Tuple[Int, Int]` Calculate start and end indices for a partition. **Args:** * ​num\_keys ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of keys (sequence length). * ​num\_partitions ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of partitions to split keys into. * ​partition\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Index of current partition (0 to num\_partitions-1). **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple of (start\_idx, end\_idx) for the partition, aligned to tile\_size.
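The partition arithmetic above can be sketched in plain Python. This is a minimal sketch: the exact tiling policy (ceiling division of tiles across partitions, with the last end clamped to `num_keys`) is an assumption for illustration, not the kernel's verified implementation.

```python
def get_start_and_end_for_partitions(num_keys: int, num_partitions: int,
                                     partition_idx: int, tile_size: int):
    # Cover the keys with tiles, then split the tiles across partitions
    # using ceiling division so every partition gets a tile-aligned span.
    num_tiles = (num_keys + tile_size - 1) // tile_size
    tiles_per_partition = (num_tiles + num_partitions - 1) // num_partitions
    # Each start is aligned to tile_size; the final end is clamped to num_keys.
    start = min(partition_idx * tiles_per_partition * tile_size, num_keys)
    end = min(start + tiles_per_partition * tile_size, num_keys)
    return start, end
```

Under this policy the partitions cover `[0, num_keys)` without overlap, and every boundary except possibly the last end falls on a `tile_size` multiple.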
--- ## mha_utils
## `comptime` values ### `callback_fn_type` `comptime callback_fn_type = fn[mask_t: MHAMask, score_mod_t: ScoreModTrait](mask: mask_t, score_mod: score_mod_t) raises capturing -> None` ### `is_sm100` `comptime is_sm100 = String(_accelerator_arch()).__contains__("sm_100")` ### `is_sm90` `comptime is_sm90 = String(_accelerator_arch()).__contains__("sm_90")` ### `is_sm90or100` `comptime is_sm90or100 = is_sm90 if String(_accelerator_arch()).__contains__("sm_90") else is_sm100` ## Structs * [​`DynamicInt`](./DynamicInt): * [​`FlashAttentionAlgorithm`](./FlashAttentionAlgorithm): * [​`MHAConfig`](./MHAConfig): * [​`NoPartition`](./NoPartition): * [​`SplitKPartition`](./SplitKPartition): * [​`StaticInt`](./StaticInt): ## Traits * [​`MHAPartitionScheme`](./MHAPartitionScheme): * [​`OptionallyStaticInt`](./OptionallyStaticInt): ## Functions * [​`dispatch_mask_and_score_mod`](./dispatch_mask_and_score_mod): * [​`dispatch_materialized_mask_and_score_mod`](./dispatch_materialized_mask_and_score_mod): * [​`get_start_and_end_for_partitions`](./get_start_and_end_for_partitions): Calculate start and end indices for a partition.
--- ## flare_mla_decoding
`flare_mla_decoding[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[(rank - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(rank - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), ragged: Bool = False, decoding_warp_split_k: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: cache_t, mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, num_partitions: OptionalReg[Int] = None)` MLA decoding kernel that would only be called in the optimized compute graph. The Q input has a shape of \[seq\_len, num\_heads, depth]. The K input has a shape of \[seq\_len, 1, depth]. The V tensor is derived by reusing K, where V = K\[:, :, :depth\_v]. Specifically, for DeepSeek V2/3, depth = 576 and depth\_v = 512. This kernel computes attention without needing to load V twice. This kernel only handles decoding requests. In this case q\_max\_seq\_len = 1. This kernel handles batches with different valid lengths (i.e., before the padding). 
Such lengths are passed in the valid\_length argument. `flare_mla_decoding[mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, use_score_mod: Bool = False, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[2])), UInt(Int.__init__[IntTuple](q_layout.shape[3])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), decoding_warp_split_k: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask_functor: mask_t, score_mod_functor: score_mod_t, scale: Float32, ctx: DeviceContext, num_partitions: OptionalReg[Int] = None)`
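The V-reuse described above (V = K\[:, :, :depth\_v]) can be pictured with a plain-Python sketch; the shapes follow the DeepSeek V2/3 figures quoted in the docstring (depth = 576, depth\_v = 512).

```python
depth, depth_v = 576, 512  # DeepSeek V2/3 head dimensions from the docstring

def v_element(k_row, d):
    # V[i, 0, d] is simply K[i, 0, d] for d < depth_v, so the kernel can
    # compute attention without loading a separate V tensor from memory.
    assert d < depth_v
    return k_row[d]

k_row = [float(d) for d in range(depth)]            # one K row: K[i, 0, :]
v_row = [v_element(k_row, d) for d in range(depth_v)]  # V[i, 0, :] aliases K
```

Because V is just a prefix view of K along the depth axis, the decoding kernel only ever streams K through shared memory.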
--- ## flare_mla_decoding_dispatch
`flare_mla_decoding_dispatch[k_t: MHAOperand, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, q_layout: Layout, //, kv_num_heads: Int, use_score_mod: Bool = False, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), ragged: Bool = False, _is_cache_length_accurate: Bool = False, _use_valid_length: Bool = True, decoding_warp_split_k: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: k_t, mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_prompt_len: Int, max_cache_valid_length: Int, scale: Float32, ctx: DeviceContext, kv_input_row_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, num_partitions: OptionalReg[Int] = None)`
--- ## flare_mla_prefill
`flare_mla_prefill[rank: Int, cache_t: KVCacheT, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, output_type: DType, softmax_type: DType, q_layout: Layout, //, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False, use_fa4: Bool = False](output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_rope: cache_t, mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutAnyOrigin]] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, prev_output: OptionalReg[LayoutTensor[output_type, Layout.row_major[rank](), MutAnyOrigin]] = None, prev_softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutAnyOrigin]] = None)` MLA prefill kernel that would only be called in the optimized 
compute graph. Only supports ragged Q/K/V inputs. The Q input has a shape of \[seq\_len, num\_heads, q\_depth]. The K and V inputs have a shape of \[cache\_len, num\_heads, depth]. The K\_rope input is retrieved from the KV cache, with a shape of \[cache\_len, 1, q\_depth - depth]. Specifically, for DeepSeek V2/3, depth = 128 and q\_depth = 192. When computing attention scores (Q @ K), each K head is smaller than a Q head. The missing 64 elements of K are retrieved from the K cache and broadcast to all heads. This kernel also handles the output having a reduced dimension compared to the input Q. This kernel handles batches with different valid lengths (i.e., before the padding). Such lengths are passed in the valid\_length argument. `flare_mla_prefill[rank: Int, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, softmax_type: DType, q_layout: Layout, //, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False, use_fa4: Bool = False](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], v: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_rope: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: LayoutTensor[DType.uint32, layout, origin, 
element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scale: Float32, ctx: DeviceContext, q_max_seq_len: OptionalReg[Int] = None, softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutAnyOrigin]] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None)`
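The head-completion step described above — each K head is depth wide, and the shared q\_depth - depth rope elements from the cache are broadcast to every head before Q @ K — can be sketched as follows. The shapes are small illustrative values; the real kernel operates on cache pages in GPU memory, not Python lists.

```python
cache_len, num_heads = 4, 2
depth, q_depth = 128, 192        # DeepSeek V2/3 values from the docstring
rope_depth = q_depth - depth     # the missing 64 elements come from the K cache

# Per-head K part: [cache_len, num_heads, depth]
k = [[[0.0] * depth for _ in range(num_heads)] for _ in range(cache_len)]
# Shared rope part from the cache: [cache_len, 1, rope_depth]
k_rope = [[[1.0] * rope_depth] for _ in range(cache_len)]

def full_k(t, h):
    # Complete head h at position t: the per-head part, plus the shared
    # rope part broadcast (reused) across all heads.
    return k[t][h] + k_rope[t][0]
```

Every head sees the same trailing 64 rope elements, which is why K\_rope only needs a single head in the cache.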
--- ## flare_mla_prefill_dispatch
`flare_mla_prefill_dispatch[rank: Int, k_t: MHAOperand, v_t: MHAOperand, k_rope_t: MHAOperand, mask_t: MHAMask, score_mod_t: ScoreModTrait, dtype: DType, output_type: DType, softmax_type: DType, q_layout: Layout, //, kv_num_heads: Int, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False, q_depth: Int = 192, cache_depth: Int = 576, config: MHAConfig[dtype] = MHAConfig[dtype](UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 2)])), UInt(Int.__init__[IntTuple](q_layout.shape[(q_layout.rank() - 1)])), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), OptionalReg[UInt](None), 4, 1, FlashAttentionAlgorithm(-1), TensorMapSwizzle.SWIZZLE_128B), _ndbuffer_mha_operand: Bool = False, use_fa4: Bool = False](output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: k_t, v: v_t, k_rope: k_rope_t, mask_functor: mask_t, score_mod_functor: score_mod_t, valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_prompt_len: Int, scale: Float32, ctx: DeviceContext, softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutAnyOrigin]] = None, cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]] = None, prev_output: OptionalReg[LayoutTensor[output_type, Layout.row_major[rank](), MutAnyOrigin]] = None, prev_softmax_info: OptionalReg[LayoutTensor[softmax_type, Layout.row_major[3](), MutAnyOrigin]] = None)`
--- ## mla (Mla)
## Functions * [​`flare_mla_decoding`](./flare_mla_decoding): MLA decoding kernel that would only be called in the optimized compute graph. * [​`flare_mla_decoding_dispatch`](./flare_mla_decoding_dispatch): * [​`flare_mla_prefill`](./flare_mla_prefill): MLA prefill kernel that would only be called in the optimized compute graph. Only supports ragged Q/K/V inputs. * [​`flare_mla_prefill_dispatch`](./flare_mla_prefill_dispatch): * [​`mla_decoding`](./mla_decoding): * [​`mla_decoding_single_batch`](./mla_decoding_single_batch): Flash attention v2 algorithm. * [​`mla_prefill`](./mla_prefill): * [​`mla_prefill_plan`](./mla_prefill_plan): This calls a GPU kernel that plans how to process a batch of sequences with varying lengths using a fixed-size buffer. * [​`mla_prefill_plan_kernel`](./mla_prefill_plan_kernel): * [​`mla_prefill_single_batch`](./mla_prefill_single_batch): MLA for encoding where seqlen > 1. * [​`set_buffer_lengths_to_zero`](./set_buffer_lengths_to_zero):
--- ## mla_decoding
`mla_decoding[q_type: DType, k_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, valid_layout: Layout, BM: UInt, BN: UInt, BK: UInt, WM: UInt, WN: UInt, depth: UInt, num_heads: UInt, num_threads: UInt, num_pipeline_stages: UInt, group: UInt = 1, use_score_mod: Bool = False, ragged: Bool = False, _use_valid_length: Bool = False, _is_cache_length_accurate: Bool = False, decoding_warp_split_k: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], scale: Float32, batch_size: Int, num_partitions: Int, max_cache_valid_length: Int, valid_length: LayoutTensor[DType.uint32, valid_layout, MutAnyOrigin], mask: mask_t, score_mod: score_mod_t)`
--- ## mla_decoding_single_batch
`mla_decoding_single_batch[q_type: DType, k_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, BM: UInt, BN: UInt, BK: UInt, WM: UInt, WN: UInt, depth: UInt, depth_v: UInt, num_heads: UInt, num_threads: UInt, num_pipeline_stages: UInt, group: UInt = 1, use_score_mod: Bool = False, decoding_warp_split_k: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], exp_sum_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], qk_max_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], scale: Float32, num_keys: UInt, num_partitions: UInt, max_cache_valid_length: UInt, mask: mask_t, score_mod: score_mod_t, batch_idx: Int)` Flash attention v2 algorithm.
--- ## mla_prefill
`mla_prefill[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, k_rope_t: MHAOperand, output_type: DType, softmax_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, valid_layout: Layout, config: MHAConfig[dtype], group: Int = 128, q_depth: Int = 192, cache_depth: Int = 576, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False, _ndbuffer_mha_operand: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, k_rope: k_rope_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], softmax_info_ptr: LegacyUnsafePointer[Scalar[softmax_type]], prev_output_ptr: LegacyUnsafePointer[Scalar[output_type]], prev_softmax_info_ptr: LegacyUnsafePointer[Scalar[softmax_type]], scale: Float32, batch_size: Int, seq_len_arg: Int, valid_length: LayoutTensor[DType.uint32, valid_layout, MutAnyOrigin], cache_offsets: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1), MutAnyOrigin]], mask: mask_t, score_mod: score_mod_t)`
--- ## mla_prefill_plan
`mla_prefill_plan[cache_t: KVCacheT](buffer_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], buffer_lengths: LayoutTensor[DType.int32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_cache: cache_t, buffer_token_size: UInt32, ctx: DeviceContext)` This calls a GPU kernel that plans how to process a batch of sequences with varying lengths using a fixed-size buffer. Each sequence in the batch has some existing cached tokens and new input tokens. The kernel divides the total tokens into chunks of buffer\_token\_size. For each chunk (iteration), it calculates: 1\. Buffer offsets for each sequence in each chunk 2\. Cache offsets for each sequence in each chunk 3\. Total buffer lengths for each processing iteration
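The chunking described above can be approximated in Python. This sketch computes only the per-iteration buffer lengths (item 3), assuming the batch's total token count is split into buffer\_token\_size chunks with a partially full final chunk; the real kernel additionally emits per-sequence buffer and cache offsets for each chunk.

```python
def plan_buffer_lengths(seq_lens, buffer_token_size):
    # Total tokens across the batch, processed in fixed-size chunks;
    # each list entry is the token count handled in one iteration.
    total = sum(seq_lens)
    lengths = []
    remaining = total
    while remaining > 0:
        lengths.append(min(buffer_token_size, remaining))
        remaining -= buffer_token_size
    return lengths
```

The number of iterations is therefore the ceiling of total tokens over the buffer size, with only the final iteration underfilled.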
--- ## mla_prefill_plan_kernel
`mla_prefill_plan_kernel[buffer_row_offsets_layout: Layout, cache_offsets_layout: Layout, buffer_lengths_layout: Layout, input_row_offsets_layout: Layout, cache_t: KVCacheT](buffer_row_offsets: LayoutTensor[DType.uint32, buffer_row_offsets_layout, MutAnyOrigin], cache_offsets: LayoutTensor[DType.uint32, cache_offsets_layout, MutAnyOrigin], buffer_lengths: LayoutTensor[DType.int32, buffer_lengths_layout, MutAnyOrigin], input_row_offsets: LayoutTensor[DType.uint32, input_row_offsets_layout, MutAnyOrigin], k_cache: cache_t, buffer_token_size: UInt32)`
--- ## mla_prefill_single_batch
`mla_prefill_single_batch[q_type: DType, k_t: MHAOperand, v_t: MHAOperand, k_rope_t: MHAOperand, output_type: DType, mask_t: MHAMask, score_mod_t: ScoreModTrait, *, config: MHAConfig[dtype], group: Int = 1, q_depth: Int = 192, cache_depth: Int = 576, use_score_mod: Bool = False, write_softmax_info: Bool = False, use_cascade_attention: Bool = False](q_ptr: LegacyUnsafePointer[Scalar[q_type]], k: k_t, v: v_t, k_rope: k_rope_t, output_ptr: LegacyUnsafePointer[Scalar[output_type]], softmax_info_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], prev_output_ptr: LegacyUnsafePointer[Scalar[output_type]], prev_softmax_info_ptr: LegacyUnsafePointer[Scalar[get_accum_type[q_type]()]], scale: Float32, seq_len: Int, max_seq_len: Int, start_pos: UInt32, cache_start_pos: UInt32, num_keys: Int, mask: mask_t, score_mod: score_mod_t, batch_idx: Int)` MLA for encoding where seqlen > 1.
--- ## set_buffer_lengths_to_zero
`set_buffer_lengths_to_zero[buffer_lengths_layout: Layout](buffer_lengths: LayoutTensor[DType.int32, buffer_lengths_layout, MutAnyOrigin])`
--- ## mla_graph
## Functions * [​`mla_prefill_branch_fp8`](./mla_prefill_branch_fp8): This is a manually fused kernel that performs the following operations: - Copy the KV latent values from PagedKVCache to a contiguous buffer. - Quantize the KV latent values to fp8. - Up-project the latent KV values to full K and V through a matmul. - Split the concatenated KV into K and V. - Perform MLA prefill.
--- ## mla_prefill_branch_fp8
`mla_prefill_branch_fp8[dtype: DType, fp8_dtype: DType, fp8_scale_dtype: DType, collection_t: KVCollectionT, //, qk_nope_head_dim: Int, m_scale_granularity: Int, n_scale_granularity: Int, k_scale_granularity: Int, mask_str: StringSlice[StaticConstantOrigin], score_mod_str: StringSlice[StaticConstantOrigin], target: StringSlice[StaticConstantOrigin] = "cpu"](output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_collection: collection_t, layer_idx: UInt32, scale: Float32, buffer_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], buffer_length: Int, kv_b_proj: LayoutTensor[fp8_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], kv_b_proj_scale: LayoutTensor[fp8_scale_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContext)` This is a manually fused kernel that performs the following operations: - Copy the KV latent values from PagedKVCache to a contiguous buffer. - Quantize the KV latent values to fp8. 
- Up-project the latent KV values to full K and V through a matmul. - Split the concatenated KV into K and V. - Perform MLA prefill. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the input and output tensors. * ​fp8\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the fp8 input and output tensors. * ​fp8\_scale\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the fp8 scale input and output tensors. * ​collection\_t ([`KVCollectionT`](/mojo/kernels/kv_cache/types/KVCollectionT)): Type of the KV collection. * ​qk\_nope\_head\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): Dimension of non-rope parts of the Q/K heads. * ​m\_scale\_granularity ([`Int`](/mojo/stdlib/builtin/int/Int)): Granularity of the scale for M dimension of the matrix multiplication. * ​n\_scale\_granularity ([`Int`](/mojo/stdlib/builtin/int/Int)): Granularity of the scale for N dimension of the matrix multiplication. * ​k\_scale\_granularity ([`Int`](/mojo/stdlib/builtin/int/Int)): Granularity of the scale for K dimension of the matrix multiplication. * ​mask\_str (`StringSlice`): Mask variant. * ​score\_mod\_str (`StringSlice`): Positional encoding variant. * ​target (`StringSlice`): Target device. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor of shape \[tot\_seq\_len, num\_heads, v\_head\_dim]. * ​q ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Query tensor of shape \[tot\_seq\_len, num\_heads, qk\_nope\_head\_dim + qk\_rope\_head\_dim]. * ​input\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Indicates where each request starts and ends in `q`. Shape: [num\_batches + 1]. * ​kv\_collection (`collection_t`): Paged KV Cache object. * ​layer\_idx ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Layer index. * ​scale ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): Scale for the attention calculation. 
* ​buffer\_row\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Indicates where each request's KV latent values should be stored in the contiguous K buffer. This is a 1D tensor of shape \[num\_batches + 1]. * ​cache\_offsets ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Indicates the starting token position in the KV cache from which to copy KV latent values for each request. This is a 1D tensor of shape \[num\_batches + 1]. * ​buffer\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The total number of tokens in the KV cache. Scalar. * ​kv\_b\_proj ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Weight matrix for up-projecting the KV latent values to full K and V. Shape: [num\_heads \* (qk\_nope\_head\_dim + v\_head\_dim), kv\_latent\_dim]. * ​kv\_b\_proj\_scale ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The scale for the weight matrix. Shape varies depending on the float8\_config. * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context.
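The "split the concatenated KV into K and V" step can be sketched as below. The assumed layout — each head contributes its qk\_nope\_head\_dim K values followed by its v\_head\_dim V values, with heads contiguous — is consistent with the kv\_b\_proj shape \[num\_heads \* (qk\_nope\_head\_dim + v\_head\_dim), kv\_latent\_dim] given above, but the exact interleaving inside the fused kernel is an assumption.

```python
num_heads, qk_nope_head_dim, v_head_dim = 2, 128, 128
head_dim = qk_nope_head_dim + v_head_dim

def split_kv(kv_row):
    # kv_row holds num_heads * (qk_nope_head_dim + v_head_dim) up-projected
    # values for one token; peel off the per-head K and V slices.
    k, v = [], []
    for h in range(num_heads):
        base = h * head_dim
        k.append(kv_row[base:base + qk_nope_head_dim])
        v.append(kv_row[base + qk_nope_head_dim:base + head_dim])
    return k, v
```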
--- ## MLAKVProducerPipeline
`@register_passable(trivial)` `struct MLAKVProducerPipeline[dtype: DType, config: FA4Config]` ## Fields * ​kv\_pipeline (`KVPipeline[config.num_kv_stages, config.num_mma_stages]`): * ​smem (`MLAKVProducerPipeline[dtype, config].SMemType`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `k_bytes` `comptime k_bytes = (MLAKVProducerPipeline[dtype, config].k_elements * size_of[dtype]())` ### `k_elements` `comptime k_elements = tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]().size()` ### `k_layout` `comptime k_layout = tile_layout_k_major[dtype, config.BN, 128, config.swizzle_mode]()` ### `k_rope_layout` `comptime k_rope_layout = tile_layout_k_major[dtype, config.BN, 64, config.swizzle_mode]()` ### `KPairType` `comptime KPairType = TMADestination[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode]()]` ### `KType` `comptime KType = LayoutTensor[dtype, tile_layout_k_major[dtype, config.BN, config.BK0, config.swizzle_mode](), MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]` ### `SMemType` `comptime SMemType = LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]` ### `v_bytes` `comptime v_bytes = (MLAKVProducerPipeline[dtype, config].v_elements * size_of[dtype]())` ### `v_elements` `comptime v_elements = tile_layout_mn_major[dtype, 128, config.BK1, config.swizzle_mode]().size()` ### `VPairType` `comptime VPairType = TMADestination[dtype, 
tile_layout_mn_major[dtype, 128, config.BK1, config.swizzle_mode]()]` ### `VType` `comptime VType = LayoutTensor[dtype, tile_layout_mn_major[dtype, 128, config.BK1, config.swizzle_mode](), MutAnyOrigin, address_space=AddressSpace.SHARED, layout_int_type=DType.int32, linear_idx_type=DType.int32, alignment=128]` ## Methods ### `__init__` `__init__(mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self` `__init__(kv_pipeline: KVPipeline[config.num_kv_stages, config.num_mma_stages], smem: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED]) -> Self` ### `init` `init(self)` Only one of the producer or consumer should call `init()`. ### `get_kv_smem` `get_kv_smem[*, mma_stage: Int](self) -> MLAKVProducerPipeline[dtype, config].SMemType` **Returns:** `MLAKVProducerPipeline` ### `get_k` `get_k[*, mma_stage: Int, expect: Bool = True](self) -> MLAKVProducerPipeline[dtype, config].KPairType` **Returns:** `MLAKVProducerPipeline` ### `get_v` `get_v[*, mma_stage: Int](self) -> MLAKVProducerPipeline[dtype, config].VPairType` **Returns:** `MLAKVProducerPipeline` ### `acquire_kv` `acquire_kv[*, mma_stage: Int = (config - 1)](self)` ### `commit_kv_step` `commit_kv_step(mut self)` Step the kv pipeline. This does not perform the commit on the mbars; that should be handled by `tma_op.async_copy`.
--- ## SM100MLA
`@register_passable(trivial)` `struct SM100MLA[KVLUTType: MHAOperand, output_type: DType, MaskType: MHAMask, ScoreModType: ScoreModTrait, SchedulerType: MHATileScheduler, config: FA4Config, use_score_mod: Bool, ValidLengthType: OptionalPointer, SinkType: OptionalPointer, KVRowOffsetsType: OptionalPointer, _is_cache_length_accurate: Bool, MaxSeqLenType: OptionallyStaticInt, PartitionType: MHAPartitionScheme, descriptor_shape: IndexList[3], remaining_global_dim_rank: Int]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `accum_type` `comptime accum_type = get_accum_type[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type]()` ### `BM` `comptime BM = config.BM` ### `BN` `comptime BN = config.BN` ### `cache_depth` `comptime cache_depth = 576` ### `cta_group` `comptime cta_group = 1` ### `depth` `comptime depth = config.depth` ### `group` `comptime group = config.group` ### `k_bytes` `comptime k_bytes = SIMD[DType.uint32, 1]((SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].swizzle_granularity * config)).__rmul__[DType.uint32, 1](SIMD[DType.uint32, 1](SM100MLA[KVLUTType, output_type, MaskType, 
ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_dt_size))` ### `k_elements` `comptime k_elements = SIMD[DType.uint32, 1]((SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].swizzle_granularity * config))` ### `k_rope_depth` `comptime k_rope_depth = 64` ### `kv_depth` `comptime kv_depth = (config - 64)` ### `KVPipelineType` `comptime KVPipelineType = KVPipeline[config.num_kv_stages, config.num_mma_stages]` ### `MMA_K` `comptime MMA_K = 16` ### `MMA_M` `comptime MMA_M = (config // 2)` ### `num_m_mmas` `comptime num_m_mmas = 2` ### `num_mma_stages` `comptime num_mma_stages = config.num_mma_stages` ### `num_q_heads` `comptime num_q_heads = config.num_q_heads` ### `OPipelineType` `comptime OPipelineType = MBarPipeline[2]` ### `padded_depth` `comptime padded_depth = config.padded_depth` ### `page_size` `comptime page_size = KVLUTType.page_size` ### `PositionType` `comptime PositionType = MHAPosition[config.BM, config.BN, config.depth, config.padded_depth, config.num_q_heads, config.group, _is_decoding[MaxSeqLenType]()]` ### `qkv_dt_size` `comptime qkv_dt_size = size_of[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type]()` ### `qkv_type` `comptime qkv_type = KVLUTType.dtype` ### `qo_bytes` `comptime qo_bytes = SIMD[DType.uint32, 1]((SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, 
descriptor_shape, remaining_global_dim_rank].qkv_dt_size * SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qo_elements))` ### `qo_elements` `comptime qo_elements = (SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].padded_depth * SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].MMA_M)` ### `ragged` `comptime ragged = ValidLengthType.is_null.__bool__().__invert__()` ### `simd_size` `comptime simd_size = simd_width_of[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type]()` ### `swizzle_granularity` `comptime swizzle_granularity = (config.swizzle_mode.bytes() // SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_dt_size)` ### `UMMA0Type` `comptime UMMA0Type = SM100TensorAccumulatorSS[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, 
use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].MMA_M, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].BN, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].depth, swizzle_a=config.swizzle_mode, swizzle_b=config.swizzle_mode, num_stages=SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].num_mma_stages]` ### `UMMA1Type` `comptime UMMA1Type = SM100TensorAccumulatorTS[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, 
MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].MMA_M, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].BN, config.swizzle_mode, transpose_b=False, num_stages=SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].num_mma_stages]` ### `v_bytes_per_mma` `comptime v_bytes_per_mma = SIMD[DType.uint32, 1](((SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_dt_size * 16) * config))` ## Methods ### `mla_prefill_kernel` `static mla_prefill_kernel[KRopeType: MHAOperand](q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.BK0, decoding=False](), config, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.BK0, decoding=False](), config)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, 
MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config)], k_rope_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, 64, Tuple[]()), TensorMapSwizzle.SWIZZLE_128B, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, 64, Tuple[]()), TensorMapSwizzle.SWIZZLE_128B)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config)], o_ptr_arg: LegacyUnsafePointer[Scalar[output_type]], ragged_tma_store: RaggedTensorMap[output_type, descriptor_shape, remaining_global_dim_rank, config.swizzle_mode], kv_lut: KVLUTType, scale: Float32, batch_size: UInt32, num_keys_arg: UInt32, pack: Pack[MaskType, ScoreModType, SchedulerType, ValidLengthType, SinkType, KVRowOffsetsType, MaxSeqLenType, PartitionType])` ### `correction` `static correction(tmem_addr: UInt32, mbars: FA4MiscMBars, o_mbar: 
LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], score_row: UInt32, num_keys: UInt32, mask: MaskType)` ### `softmax` `static softmax(tmem_addr: UInt32, warp_idx: UInt32, mbars: FA4MiscMBars, o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], score_row: UInt32, seq_info: SeqInfo, mask: MaskType, num_keys: UInt32, scale: Scalar[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type], score_mod: ScoreModType, max_seq_len: UInt32, o_ptr_arg: LegacyUnsafePointer[Scalar[output_type]], ragged_tma_store: RaggedTensorMap[output_type, descriptor_shape, remaining_global_dim_rank, config.swizzle_mode], o_smem: LegacyUnsafePointer[Scalar[output_type], address_space=AddressSpace.SHARED], sink_weights: SinkType)` ### `scale_write_output` `static scale_write_output(local_row: UInt32, inv_row_sum: Scalar[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type], o_smem: LegacyUnsafePointer[Scalar[output_type], address_space=AddressSpace.SHARED], o_tmem: TMemTile[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].accum_type, (SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].BM // 2), SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, 
use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth], o_ptr: LegacyUnsafePointer[Scalar[output_type]], ragged_tma_store: RaggedTensorMap[output_type, descriptor_shape, remaining_global_dim_rank, config.swizzle_mode], warp_group_idx: UInt32, consumer_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], current_seq: Int, num_output_rows: Int32)` ### `mask_status` `static mask_status(mask: MaskType, score_row: UInt32, kv_row: UInt32) -> TileMaskStatus` **Returns:** `TileMaskStatus` ### `load` `static load(mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], score_row: UInt32, num_keys: UInt32, seq_info: SeqInfo, max_seq_len: MaxSeqLenType, mask: MaskType, q_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.BK0, decoding=False](), config, True), _ragged_desc_layout[KVLUTType.dtype](q_smem_shape[KVLUTType.dtype, config.swizzle_mode, BM=(config // 2), group=config.group, depth=config.BK0, decoding=False](), config)], k_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config)], 
k_rope_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, 64, Tuple[]()), TensorMapSwizzle.SWIZZLE_128B, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, 64, Tuple[]()), TensorMapSwizzle.SWIZZLE_128B)], v_tma_op: TMATensorTile[KVLUTType.dtype, _split_last_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config, True), _ragged_desc_layout[KVLUTType.dtype](IndexList[3, DType.int64](config.BN, 1, SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].kv_depth, Tuple[]()), config)], kv_lut: KVLUTType, q_smem: LegacyUnsafePointer[Scalar[KVLUTType.dtype], address_space=AddressSpace.SHARED])` ### `descriptor_q` `static descriptor_q(q_smem: LegacyUnsafePointer[Scalar[SM100MLA[KVLUTType, output_type, MaskType, ScoreModType, SchedulerType, config, use_score_mod, ValidLengthType, SinkType, KVRowOffsetsType, _is_cache_length_accurate, MaxSeqLenType, PartitionType, descriptor_shape, remaining_global_dim_rank].qkv_type], address_space=AddressSpace.SHARED]) -> MMASmemDescriptorPair` **Returns:** [`MMASmemDescriptorPair`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptorPair) ### `mma` `static mma(tmem_addr: UInt32, mbars: FA4MiscMBars, kv_pipeline_arg: KVPipeline[config.num_kv_stages, config.num_mma_stages], o_mbar: LegacyUnsafePointer[SharedMemBarrier, address_space=AddressSpace.SHARED], score_row: UInt32, num_keys: UInt32, mask: MaskType, q_smem: LegacyUnsafePointer[Scalar[KVLUTType.dtype], 
address_space=AddressSpace.SHARED])`
--- ## mla_prefill_sm100
## Structs * [​`MLAKVProducerPipeline`](./MLAKVProducerPipeline): * [​`SM100MLA`](./SM100MLA): ## Functions * [​`mla_sm100_prefill`](./mla_sm100_prefill):
--- ## mla_sm100_prefill
`mla_sm100_prefill[output_type: DType, q_type: DType, KVType: MHAOperand, KRopeType: MHAOperand, MaskType: MHAMask, ScoreModType: ScoreModTrait, MaxPromptLenType: OptionallyStaticInt, //, config: MHAConfig[dtype], group: Int, q_depth: Int, cache_depth: Int, use_score_mod: Bool, _is_cache_length_accurate: Bool](output: LayoutTensor[output_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[q_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: KVType, v: KVType, k_rope: KRopeType, mask_functor: MaskType, score_mod_functor: ScoreModType, valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_prompt_len: MaxPromptLenType, scale: Float32, batch_size: Int, ctx: DeviceContext)`
--- ## calculate_warp_offset
`calculate_warp_offset[MaskType: DType](state: Bool) -> Tuple[UInt64, UInt64]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)
--- ## moe
## Functions * [​`calculate_warp_offset`](./calculate_warp_offset): * [​`moe_create_indices`](./moe_create_indices): * [​`moe_create_indices_bucket_group_kernel`](./moe_create_indices_bucket_group_kernel): Create indices for MoE routing using bucket sort algorithm. * [​`moe_create_indices_kernel`](./moe_create_indices_kernel):
--- ## moe_create_indices
`moe_create_indices[input_type: DType, //, target: StringSlice[StaticConstantOrigin], expected_count: Int = 8192](token_expert_order: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_start_indices: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], restore_token_order: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_ids: LayoutTensor[DType.int32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], expert_usage_stats: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], topk_ids: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr)`
--- ## moe_create_indices_bucket_group_kernel
`moe_create_indices_bucket_group_kernel[input_type: DType, token_expert_order_layout: Layout, expert_start_indices_layout: Layout, restore_token_order_layout: Layout, expert_ids_layout: Layout, expert_usage_stats_layout: Layout, topk_ids_layout: Layout, num_threads: Int = WARP_SIZE, expected_count: Int = 8192](token_expert_order: LayoutTensor[DType.uint32, token_expert_order_layout, MutAnyOrigin], lock: LayoutTensor[DType.uint32, Layout.row_major(1), MutAnyOrigin], expert_start_indices: LayoutTensor[DType.uint32, expert_start_indices_layout, MutAnyOrigin], restore_token_order: LayoutTensor[DType.uint32, restore_token_order_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_usage_stats: LayoutTensor[DType.uint32, expert_usage_stats_layout, MutAnyOrigin], topk_ids: LayoutTensor[input_type, topk_ids_layout, MutAnyOrigin])` Create indices for MoE routing using a bucket sort algorithm. The main goal of this kernel is to group tokens that use the same expert together. This allows for efficient batching when used by other kernels such as grouped matmul. This is a GPU-optimized bucket sort implementation that uses: * Warp-level voting to count matching tokens * Shared memory for temporary storage * Atomic operations for thread-safe global memory updates topk\_ids: a 1D tensor of expert ids; the index of each expert\_id corresponds to a token. For example, if topk\_ids is \[1, 0, 1, 3, 4, 2], then the corresponding tokens are \[0, 1, 2, 3, 4, 5]. token\_expert\_order: a 1D tensor of tokens grouped together by expert id. Using the previous topk\_ids, the token expert order could be \[0, 2, 1, 3, 4, 5]. expert\_ids: a 1D tensor of all the experts that are being used. Using the previous topk\_ids, our expert\_ids would be \[1, 0, 3, 4, 2]. expert\_start\_indices: tells us where each expert starts and ends in the token\_expert\_order. Based on the order of our expert\_ids, our expert\_start\_indices would be \[0, 2, 3, 4, 5, 6]. 
So to find where expert 1 starts and ends, you would take the index 'i' of expert 1 in expert\_ids and query expert\_start\_indices\[i] and expert\_start\_indices\[i + 1], which are 0 and 2 respectively. lock: a 1D tensor that holds a single scalar value; this integer is used to atomically synchronize the writes back to global memory by storing how many blocks have finished writing and the current global memory offset. expert\_usage\_stats: contains two values, the maximum number of tokens assigned to any expert and the number of active experts. For our example the stats would be \[2, 5]. restore\_token\_order: a 1D tensor where each index represents a corresponding token and holds the new index of that token in the token\_expert\_order tensor. For our example the restore\_token\_order would be \[0, 2, 1, 3, 4, 5].
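The grouping semantics described above can be sketched in plain Python (this is an illustrative reference model of the outputs, not the GPU bucket-sort kernel itself; names mirror the kernel's output tensors):

```python
# Reference sketch: compute the same outputs moe_create_indices_bucket_group_kernel
# produces, using the worked example from the docs. Illustrative only.

def moe_indices(topk_ids):
    # Bucket token indices by expert id, keeping first-seen expert order.
    expert_ids = []
    buckets = {}
    for token, expert in enumerate(topk_ids):
        if expert not in buckets:
            buckets[expert] = []
            expert_ids.append(expert)
        buckets[expert].append(token)

    # Tokens grouped together by expert id.
    token_expert_order = [t for e in expert_ids for t in buckets[e]]

    # expert_start_indices[i]..expert_start_indices[i+1] spans expert_ids[i].
    expert_start_indices = [0]
    for e in expert_ids:
        expert_start_indices.append(expert_start_indices[-1] + len(buckets[e]))

    # restore_token_order[token] = position of that token in token_expert_order.
    restore_token_order = [0] * len(topk_ids)
    for pos, token in enumerate(token_expert_order):
        restore_token_order[token] = pos

    # [max tokens assigned to any expert, number of active experts]
    expert_usage_stats = [max(len(b) for b in buckets.values()), len(buckets)]
    return (token_expert_order, expert_start_indices, expert_ids,
            restore_token_order, expert_usage_stats)
```

Running this on the example topk\_ids \[1, 0, 1, 3, 4, 2] reproduces the values quoted in the description above.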
--- ## moe_create_indices_kernel
`moe_create_indices_kernel[input_type: DType, num_threads: Int, token_expert_order_layout: Layout, expert_start_indices_layout: Layout, restore_token_order_layout: Layout, expert_ids_layout: Layout, expert_usage_stats_layout: Layout, indices_padded_layout: Layout, padded_input_layout: Layout, topk_ids_layout: Layout](token_expert_order: LayoutTensor[DType.uint32, token_expert_order_layout, MutAnyOrigin], expert_start_indices: LayoutTensor[DType.uint32, expert_start_indices_layout, MutAnyOrigin], restore_token_order: LayoutTensor[DType.uint32, restore_token_order_layout, MutAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutAnyOrigin], expert_usage_stats: LayoutTensor[DType.uint32, expert_usage_stats_layout, MutAnyOrigin], indices_padded: LayoutTensor[DType.uint32, indices_padded_layout, MutAnyOrigin], topk_ids_padded: LayoutTensor[input_type, padded_input_layout, MutAnyOrigin], topk_ids: LayoutTensor[input_type, topk_ids_layout, MutAnyOrigin])`
--- ## BoundingBox
`struct BoundingBox[dtype: DType]` Represents a 2D bounding box for object detection. The box is stored using two corner points: `nw` and `se`. **Note:** In this implementation, `nw` stores the maximum coordinates (max y, max x) and `se` stores the minimum coordinates (min y, min x). This differs from the typical interpretation of "northwest" (usually min x, max y) and "southeast" (usually max x, min y). This representation allows efficient computation of intersection and union areas. Fields: nw: Corner storing the maximum coordinates (max y, max x). se: Corner storing the minimum coordinates (min y, min x). ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for coordinate values. ## Fields * ​nw (`SIMD[dtype, 2]`): * ​se (`SIMD[dtype, 2]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, y1: Scalar[dtype], x1: Scalar[dtype], y2: Scalar[dtype], x2: Scalar[dtype])` Initialize a bounding box from two diagonal corner coordinates. Note: The corners are automatically ordered to ensure nw contains the maximum coordinates and se contains the minimum coordinates. **Args:** * ​y1 ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Y-coordinate of first corner. * ​x1 ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): X-coordinate of first corner. * ​y2 ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Y-coordinate of second corner. 
* ​x2 ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): X-coordinate of second corner. ### `iou` `iou(self, other: Self) -> Scalar[dtype]` Calculate Intersection over Union (IoU) with another bounding box. **Args:** * ​other (`Self`): The other bounding box to compare with. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The IoU value, ranging from 0 (no overlap) to 1 (perfect overlap). ### `intersection_area` `intersection_area(self, other: Self) -> Scalar[dtype]` Calculate the area of intersection with another bounding box. **Args:** * ​other (`Self`): The other bounding box to intersect with. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The intersection area, or 0 if boxes don't overlap. ### `area` `area(self) -> Scalar[dtype]` Calculate the area of this bounding box. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The area of the box.
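The corner convention above (nw = element-wise maximum, se = element-wise minimum) makes intersection and union simple min/max arithmetic. A small Python sketch of these semantics (illustrative only, not the Mojo struct):

```python
# Reference sketch of BoundingBox semantics: nw holds the max (y, x)
# coordinates, se the min (y, x) coordinates.

def make_box(y1, x1, y2, x2):
    # Order corners so nw = element-wise max, se = element-wise min.
    nw = (max(y1, y2), max(x1, x2))
    se = (min(y1, y2), min(x1, x2))
    return nw, se

def area(box):
    nw, se = box
    return (nw[0] - se[0]) * (nw[1] - se[1])

def intersection_area(a, b):
    # Overlap extent per axis; clamp to 0 when the boxes don't overlap.
    dy = max(0.0, min(a[0][0], b[0][0]) - max(a[1][0], b[1][0]))
    dx = max(0.0, min(a[0][1], b[0][1]) - max(a[1][1], b[1][1]))
    return dy * dx

def iou(a, b):
    inter = intersection_area(a, b)
    return inter / (area(a) + area(b) - inter)
```

For example, the unit-overlap pair (0,0)-(2,2) and (1,1)-(3,3) has intersection 1 and union 7, so IoU is 1/7.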
--- ## nms
## Structs * [​`BoundingBox`](./BoundingBox): Represents a 2D bounding box for object detection. ## Functions * [​`non_max_suppression`](./non_max_suppression): Perform Non-Maximum Suppression (NMS) on bounding boxes. * [​`non_max_suppression_shape_func`](./non_max_suppression_shape_func): Compute the output shape for NMS without allocating the output buffer.
--- ## non_max_suppression
`non_max_suppression[dtype: DType](boxes: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scores: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[DType.int64, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_output_boxes_per_class: Int, iou_threshold: Float32, score_threshold: Float32)` Perform Non-Maximum Suppression (NMS) on bounding boxes. This is a buffer semantic overload that writes results directly to an output tensor. NMS iteratively selects boxes with highest scores while suppressing nearby boxes with high overlap (IoU). **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for box coordinates and scores. **Args:** * ​boxes ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Rank-3 tensor of bounding boxes with shape (batch, num\_boxes, 4). Each box is \[y1, x1, y2, x2]. * ​scores ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Rank-3 tensor of scores with shape (batch, num\_classes, num\_boxes). * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Rank-2 output tensor to store selected boxes as (N, 3) where each row is \[batch\_idx, class\_idx, box\_idx]. * ​max\_output\_boxes\_per\_class ([`Int`](/mojo/stdlib/builtin/int/Int)): Maximum number of boxes to select per class. * ​iou\_threshold ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): IoU threshold for suppression. Boxes with IoU > threshold are suppressed. * ​score\_threshold ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): Minimum score threshold. 
Boxes with score < threshold are filtered out. `non_max_suppression[dtype: DType, func: fn(Int64, Int64, Int64) capturing -> None](boxes: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scores: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_output_boxes_per_class: Int, iou_threshold: Float32, score_threshold: Float32)` Implements the NonMaxSuppression operator from the ONNX spec.
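The selection loop the docs describe (per class: take the highest-scoring remaining box, suppress boxes whose IoU with it exceeds the threshold, repeat) can be sketched as follows. This is an illustrative Python model of greedy NMS, not the Mojo kernel; the IoU computation is abstracted into a caller-provided function:

```python
# Reference sketch of greedy per-class NMS. `iou` is a function
# iou(i, j) -> float over box indices (a hypothetical stand-in for the
# actual box geometry).

def nms_per_class(scores, iou, max_out, iou_threshold, score_threshold):
    # Candidates at or above the score threshold, highest score first.
    order = sorted((i for i, s in enumerate(scores) if s >= score_threshold),
                   key=lambda i: scores[i], reverse=True)
    selected = []
    for i in order:
        if len(selected) >= max_out:
            break
        # Keep the box only if it does not overlap too much with any
        # already-selected box (IoU > threshold means suppressed).
        if all(iou(i, j) <= iou_threshold for j in selected):
            selected.append(i)
    return selected
```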
--- ## non_max_suppression_shape_func
`non_max_suppression_shape_func[dtype: DType](boxes: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], scores: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_output_boxes_per_class: Int, iou_threshold: Float32, score_threshold: Float32) -> IndexList[2]` Compute the output shape for NMS without allocating the output buffer. This function performs a dry-run of NMS to determine how many boxes will be selected, allowing proper output buffer allocation. Can be removed once the graph compiler supports value semantic kernels that allocate their own output. **Args:** * ​boxes ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Rank-3 tensor of bounding boxes with shape (batch, num\_boxes, 4). * ​scores ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Rank-3 tensor of scores with shape (batch, num\_classes, num\_boxes). * ​max\_output\_boxes\_per\_class ([`Int`](/mojo/stdlib/builtin/int/Int)): Maximum number of boxes to select per class. * ​iou\_threshold ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): IoU threshold for suppression. * ​score\_threshold ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): Minimum score threshold. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): A 2-element IndexList specifying the output shape (num\_selected\_boxes, 3).
--- ## block_reduce
`block_reduce[dtype: DType, max_warps_per_block: Int](val: Scalar[dtype]) -> Scalar[dtype]` **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar)
--- ## group_norm
`group_norm[dtype: DType, rank: Int, input_fn: fn[width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width], beta_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width], /, target: StringSlice[StaticConstantOrigin] = "gpu"](shape: IndexList[rank], epsilon: Scalar[dtype], groups: Int32, output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)`
--- ## group_norm_gpu
`group_norm_gpu[dtype: DType, rank: Int, //, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width], beta_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width]](shape: IndexList[rank, element_type=element_type], epsilon: Scalar[dtype], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], num_groups: Int, ctx: DeviceContext)`
--- ## group_norm_gpu_block
`group_norm_gpu_block[mut: Bool, origin: Origin[mut], layout: Layout, //, dtype: DType, simd_width: UInt, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width], beta_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width]](output: LayoutTensor[dtype, layout, origin], epsilon: Scalar[dtype], num_groups: Int, channels_per_group: Int, spatial: Int)`
--- ## group_norm_gpu_warp_tiling
`group_norm_gpu_warp_tiling[mut: Bool, origin: Origin[mut], layout: Layout, //, dtype: DType, simd_width: Int, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width], beta_fn: fn[width: Int](IndexList[1]) capturing -> SIMD[dtype, width]](output: LayoutTensor[dtype, layout, origin], epsilon: Scalar[dtype], num_groups: Int, channels_per_group: Int, spatial: Int)`
--- ## group_norm_reshape
`group_norm_reshape[dtype: DType, rank: Int](shape: IndexList[rank, element_type=element_type], buf: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], channels_per_group: Int, spatial: Int) -> LayoutTensor[dtype, Layout.row_major[2](), origin, address_space=address_space]` Reshapes an input buffer for group normalization by flattening all dimensions except the group dimension. Returns a 2D buffer of shape (num\_groups \* N, group\_size), where group\_size is the product of channels\_per\_group and spatial. **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
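The index mapping behind this reshape can be sketched in Python: a contiguous buffer of logical shape (N, C, spatial) is viewed as (num\_groups \* N, group\_size) rows, so each row holds exactly one normalization group. Illustrative only, not the Mojo LayoutTensor code:

```python
# Reference sketch of the group_norm_reshape row mapping over a flat,
# row-major buffer.

def group_norm_rows(flat, n, num_groups, channels_per_group, spatial):
    group_size = channels_per_group * spatial
    rows = []
    for row in range(n * num_groups):
        start = row * group_size
        rows.append(flat[start:start + group_size])
    return rows
```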
--- ## group_norm_shape
`group_norm_shape[dtype: DType, single_thread_blocking_override: Bool](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], beta: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], num_groups: Int32) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## normalization
## Functions * [​`block_reduce`](./block_reduce): * [​`group_norm`](./group_norm): * [​`group_norm_gpu`](./group_norm_gpu): * [​`group_norm_gpu_block`](./group_norm_gpu_block): * [​`group_norm_gpu_warp_tiling`](./group_norm_gpu_warp_tiling): * [​`group_norm_reshape`](./group_norm_reshape): Reshapes an input buffer for group normalization by flattening all dimensions except the group dimension. Returns a 2D buffer of shape (num\_groups \* N, group\_size), where group\_size is the product of channels\_per\_group and spatial. * [​`group_norm_shape`](./group_norm_shape): * [​`layer_norm`](./layer_norm): * [​`layer_norm_cpu`](./layer_norm_cpu): Computes layernorm(elementwise\_fn(x)) across the last dimension of x, where layernorm is defined as $(x - \mathrm{mean}(x)) / \sqrt{\mathrm{var}(x) + \epsilon} \cdot \gamma + \beta$. * [​`layer_norm_gpu`](./layer_norm_gpu): * [​`layer_norm_gpu_block`](./layer_norm_gpu_block): * [​`layer_norm_gpu_warp_tiling`](./layer_norm_gpu_warp_tiling): * [​`layer_norm_reshape`](./layer_norm_reshape): * [​`layer_norm_shape`](./layer_norm_shape): Compute the output shape of a `layer_norm` operation.
* [​`rms_norm`](./rms_norm): * [​`rms_norm_cpu`](./rms_norm_cpu): * [​`rms_norm_fused_residual_add`](./rms_norm_fused_residual_add): * [​`rms_norm_fused_residual_add_cpu`](./rms_norm_fused_residual_add_cpu): * [​`rms_norm_fused_residual_add_gpu`](./rms_norm_fused_residual_add_gpu): * [​`rms_norm_fused_residual_add_gpu_block`](./rms_norm_fused_residual_add_gpu_block): * [​`rms_norm_fused_residual_add_gpu_warp_tiling`](./rms_norm_fused_residual_add_gpu_warp_tiling): * [​`rms_norm_gpu`](./rms_norm_gpu): * [​`rms_norm_gpu_block`](./rms_norm_gpu_block): * [​`rms_norm_gpu_warp_tiling`](./rms_norm_gpu_warp_tiling): * [​`rms_norm_gpu_warp_tiling_128`](./rms_norm_gpu_warp_tiling_128): * [​`rms_norm_shape`](./rms_norm_shape): * [​`welford_block_all_reduce`](./welford_block_all_reduce): * [​`welford_combine`](./welford_combine): * [​`welford_update`](./welford_update): * [​`welford_warp_all_reduce`](./welford_warp_all_reduce): * [​`welford_warp_reduce`](./welford_warp_reduce):
--- ## layer_norm
`layer_norm[dtype: DType, rank: Int, input_0_fn: fn[_width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[dtype, _width], input_1_fn: fn[_width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[dtype, _width], output_0_fn: fn[width: Int, rank: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None, /, target: StringSlice[StaticConstantOrigin] = "cpu"](shape: IndexList[rank], gamma_shape: IndexList[1], beta: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], ctx: DeviceContextPtr)`
--- ## layer_norm_cpu
`layer_norm_cpu[dtype: DType, //, input_fn: fn[width: Int](Int, Int) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None](num_rows: Int, num_cols: Int, beta: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype])` Computes layernorm(elementwise\_fn(x)) across the last dimension of x, where layernorm is defined as $(x - \mathrm{mean}(x)) / \sqrt{\mathrm{var}(x) + \epsilon} \cdot \gamma + \beta$. Currently performs 3 passes over the input data. This can be reduced to 2 by fusing the add, mean, and variance loops using Welford's algorithm. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output buffers' elements. * ​input\_fn (`fn[width: Int](Int, Int) capturing -> SIMD[dtype, width]`): Function called to generate an input value. * ​gamma\_fn (`fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): Function called to generate a gamma value. * ​output\_fn (`fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None`): Function called to store the output value. **Args:** * ​num\_rows ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of rows in the input tensor. * ​num\_cols ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of columns in the input tensor. * ​beta ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The beta value to use in the layernorm calculation. * ​epsilon ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The eps value to use in the layernorm calculation.
`layer_norm_cpu[dtype: DType, rank: Int, //, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, rank: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None](shape: IndexList[rank], beta: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype])`
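A minimal plain-Python sketch of the per-row computation (separate passes for the mean, the population variance, and the normalization itself, as in the naive path described above; not the Mojo kernel):

```python
import math

def layer_norm_row(row, gamma, beta, eps):
    # layernorm(x) = (x - mean(x)) / sqrt(var(x) + eps) * gamma + beta,
    # computed over one row with three explicit passes.
    n = len(row)
    mean = sum(row) / n
    var = sum((x - mean) ** 2 for x in row) / n  # population variance
    inv_std = 1.0 / math.sqrt(var + eps)
    return [(x - mean) * inv_std * g + beta for x, g in zip(row, gamma)]
```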
--- ## layer_norm_gpu
`layer_norm_gpu[dtype: DType, rank: Int, //, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, rank: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None](shape: IndexList[rank, element_type=element_type], beta: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], *, ctx: DeviceContext)`
--- ## layer_norm_gpu_block
`layer_norm_gpu_block[mut: Bool, origin: Origin[mut], layout: Layout, dtype: DType, //, simd_width: UInt, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None](shape: IndexList[2], beta: LayoutTensor[dtype, layout, origin], epsilon: Scalar[dtype])`
--- ## layer_norm_gpu_warp_tiling
`layer_norm_gpu_warp_tiling[mut: Bool, origin: Origin[mut], layout: Layout, dtype: DType, //, simd_width: UInt, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], gamma_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None](shape: IndexList[2], beta: LayoutTensor[dtype, layout, origin], epsilon: Scalar[dtype])`
--- ## layer_norm_reshape
`layer_norm_reshape[rank: Int, //, output_rank: Int](shape: IndexList[rank, element_type=element_type]) -> IndexList[output_rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## layer_norm_shape
`layer_norm_shape[dtype: DType, single_thread_blocking_override: Bool](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], gamma: LayoutTensor[dtype, Layout.row_major(1), origin], beta: LayoutTensor[dtype, Layout.row_major(1), origin], epsilon: Scalar[dtype]) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `layer_norm` operation. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensors. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​gamma ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor for the gamma coefficient. * ​beta ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor for the beta coefficient. * ​epsilon ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The epsilon value used in the normalization. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## rms_norm
`rms_norm[dtype: DType, rank: Int, input_0_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_0_fn: fn[width: Int, rank: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None, /, target: StringSlice[StaticConstantOrigin] = "cpu", multiply_before_cast: Bool = True](shape: IndexList[rank], gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], ctx: DeviceContextPtr)`
--- ## rms_norm_cpu
`rms_norm_cpu[dtype: DType, //, input_fn: fn[width: Int](Int, Int) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](Int, Int, SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], out_shape: IndexList[2])` `rms_norm_cpu[dtype: DType, rank: Int, //, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](shape: IndexList[rank], gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], weight_offset: Scalar[dtype])`
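A hedged plain-Python sketch of the per-row RMS-norm numerics (not the Mojo kernel). Treating `weight_offset` as an additive shift on gamma is an assumption here (models such as Gemma scale by gamma + 1 in exactly this way), and `multiply_before_cast` is omitted since it only affects where a dtype cast happens:

```python
import math

def rms_norm_row(row, gamma, eps, weight_offset):
    # RMS norm: x / sqrt(mean(x^2) + eps), scaled per element by
    # (gamma + weight_offset). The weight_offset interpretation is an
    # assumption, not confirmed by the signature alone.
    n = len(row)
    mean_sq = sum(x * x for x in row) / n
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [x * inv_rms * (g + weight_offset) for x, g in zip(row, gamma)]
```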
--- ## rms_norm_fused_residual_add
`rms_norm_fused_residual_add[dtype: DType, rank: Int, //, input_0_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], input_1_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_0_fn: fn[width: Int, rank: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None, output_residual_fn: fn[width: Int, rank: Int, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, /, target: StringSlice[StaticConstantOrigin] = "cpu", multiply_before_cast: Bool = True](shape: IndexList[rank], gamma1: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon1: Scalar[dtype], weight_offset1: Scalar[dtype], gamma2: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon2: Scalar[dtype], weight_offset2: Scalar[dtype], ctx: DeviceContextPtr)`
--- ## rms_norm_fused_residual_add_cpu
`rms_norm_fused_residual_add_cpu[dtype: DType, rank: Int, //, input_0_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], residual_input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_0_fn: fn[width: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None, output_residual_fn: fn[width: Int, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, /, multiply_before_cast: Bool = True](shape: IndexList[rank], gamma1: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon1: Scalar[dtype], weight_offset1: Scalar[dtype], gamma2: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon2: Scalar[dtype], weight_offset2: Scalar[dtype])`
--- ## rms_norm_fused_residual_add_gpu
`rms_norm_fused_residual_add_gpu[dtype: DType, rank: Int, //, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], residual_input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_residual_fn: fn[width: Int, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, output_fn: fn[width: Int, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](shape: IndexList[rank, element_type=element_type], gamma1: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon1: Scalar[dtype], weight_offset1: Scalar[dtype], gamma2: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon2: Scalar[dtype], weight_offset2: Scalar[dtype], ctx: DeviceContext)`
--- ## rms_norm_fused_residual_add_gpu_block
`rms_norm_fused_residual_add_gpu_block[mut1: Bool, origin1: Origin[mut1], layout1: Layout, mut2: Bool, origin2: Origin[mut2], layout2: Layout, dtype: DType, //, simd_width: Int, max_warps_per_block: Int, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], residual_input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None, output_residual_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](gamma1: LayoutTensor[dtype, layout1, origin1], epsilon1: Scalar[dtype], weight_offset1: Scalar[dtype], gamma2: LayoutTensor[dtype, layout2, origin2], epsilon2: Scalar[dtype], weight_offset2: Scalar[dtype], num_cols: Int)`
--- ## rms_norm_fused_residual_add_gpu_warp_tiling
`rms_norm_fused_residual_add_gpu_warp_tiling[mut1: Bool, origin1: Origin[mut1], layout1: Layout, mut2: Bool, origin2: Origin[mut2], layout2: Layout, dtype: DType, //, simd_width: Int, max_warps_per_block: Int, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], residual_input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None, output_residual_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](gamma1: LayoutTensor[dtype, layout1, origin1], epsilon1: Scalar[dtype], weight_offset1: Scalar[dtype], gamma2: LayoutTensor[dtype, layout2, origin2], epsilon2: Scalar[dtype], weight_offset2: Scalar[dtype], num_cols: Int)`
--- ## rms_norm_gpu
`rms_norm_gpu[dtype: DType, rank: Int, //, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](shape: IndexList[rank, element_type=element_type], gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], ctx: DeviceContext)`
--- ## rms_norm_gpu_block
`rms_norm_gpu_block[mut: Bool, origin: Origin[mut], layout: Layout, dtype: DType, //, simd_width: Int, max_warps_per_block: Int, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](gamma: LayoutTensor[dtype, layout, origin], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], num_cols: Int)`
--- ## rms_norm_gpu_warp_tiling
`rms_norm_gpu_warp_tiling[mut: Bool, origin: Origin[mut], layout: Layout, dtype: DType, //, simd_width: Int, max_warps_per_block: Int, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](gamma: LayoutTensor[dtype, layout, origin], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], num_cols: Int)`
--- ## rms_norm_gpu_warp_tiling_128
`rms_norm_gpu_warp_tiling_128[mut: Bool, origin: Origin[mut], layout: Layout, dtype: DType, //, simd_width: Int, warps_per_block: Int, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, alignment: Int](row: Int, col: Int, val: SIMD[dtype, width]) capturing -> None, multiply_before_cast: Bool](gamma: LayoutTensor[dtype, layout, origin], epsilon: Scalar[dtype], weight_offset: Scalar[dtype], num_rows: Int, num_cols: Int)`
--- ## rms_norm_shape
`rms_norm_shape[dtype: DType, single_thread_blocking_override: Bool](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], gamma: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], epsilon: Scalar[dtype], weight_offset: Scalar[dtype]) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## welford_block_all_reduce
`welford_block_all_reduce[dtype: DType, //](thread_mean: Scalar[dtype], thread_m2: Scalar[dtype], thread_count: Scalar[dtype], mut res_mean: Scalar[dtype], mut res_m2: Scalar[dtype], mut res_count: Scalar[dtype])`
--- ## welford_combine
`welford_combine[dtype: DType, //](mean: Scalar[dtype], m2: Scalar[dtype], count: Scalar[dtype], mut res_mean: Scalar[dtype], mut res_m2: Scalar[dtype], mut res_count: Scalar[dtype])`
--- ## welford_update
`welford_update[dtype: DType, //](val: Scalar[dtype], mut mean: Scalar[dtype], mut m2: Scalar[dtype], mut count: Scalar[dtype])`
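A plain-Python sketch of the single-sample update and the pairwise merge applied by `welford_combine` and the warp/block reductions. Counts are kept as integers here, whereas the Mojo API carries them as `Scalar[dtype]`:

```python
def welford_update(val, mean, m2, count):
    # One Welford step: maintain running mean and M2 (sum of squared
    # deviations) in a single pass over the data.
    count += 1
    d = val - mean
    mean += d / count
    m2 += d * (val - mean)
    return mean, m2, count

def welford_combine(mean_a, m2_a, n_a, mean_b, m2_b, n_b):
    # Chan et al. pairwise merge of two partial (mean, M2, count)
    # states, as used when reducing per-lane statistics.
    n = n_a + n_b
    if n == 0:
        return 0.0, 0.0, 0
    d = mean_b - mean_a
    mean = mean_a + d * n_b / n
    m2 = m2_a + m2_b + d * d * n_a * n_b / n
    return mean, m2, n
```

Merging two halves of a dataset yields the same (mean, M2) as a single sequential pass, which is what makes the tree-shaped warp/block reductions valid.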
--- ## welford_warp_all_reduce
`welford_warp_all_reduce[dtype: DType, //](thread_mean: Scalar[dtype], thread_m2: Scalar[dtype], thread_count: Scalar[dtype], mut res_mean: Scalar[dtype], mut res_m2: Scalar[dtype], mut res_count: Scalar[dtype])`
--- ## welford_warp_reduce
`welford_warp_reduce[dtype: DType, //](thread_mean: Scalar[dtype], thread_m2: Scalar[dtype], thread_count: Scalar[dtype], mut res_mean: Scalar[dtype], mut res_m2: Scalar[dtype], mut res_count: Scalar[dtype])`
--- ## pad (Pad)
## Functions * [​`pad_constant`](./pad_constant): Fill `output` with values from `input`, and edges padded with `constant` based on `paddings`. * [​`pad_reflect`](./pad_reflect): Fill `output` with values from `input`, and edges padded with reflected values from the unpadded region. * [​`pad_repeat`](./pad_repeat): Fill `output` with values from `input`, and edges padded with boundary values from the unpadded region. * [​`pad_shape`](./pad_shape): Compute the output shape of a `pad` operation, and assert the inputs are compatible.
--- ## pad_constant
`pad_constant[output_layout: Layout, input_layout: Layout, dtype: DType, paddings_type: DType, constant_type: DType](output: LayoutTensor[dtype, output_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, input_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LegacyUnsafePointer[Scalar[paddings_type]], constant: Scalar[constant_type])` Fill `output` with values from `input`, and edges padded with `constant` based on `paddings`. Example:

```mojo
var input_shape = (X, Y, Z)
var paddings = [x0, x1, y0, y1, z0, z1]

out[x, y, z] = input[x - x0, y - y0, z - z0]
    if x ∈ [x0, x0 + X] && y ∈ [y0, y0 + Y] && z ∈ [z0, z0 + Z]
    else constant
```

**Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer. * ​paddings (`LegacyUnsafePointer`): Ordered (before, after) padding sizes for each axis. * ​constant ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The constant to pad output with.
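The same semantics can be sketched in plain Python for the 2D case (illustrative only, not the Mojo kernel):

```python
def pad_constant_2d(inp, paddings, constant):
    # paddings = [r0, r1, c0, c1]: flat (before, after) pairs per axis,
    # matching the kernel's paddings pointer layout.
    r0, r1, c0, c1 = paddings
    rows, cols = len(inp), len(inp[0])
    out = []
    for r in range(rows + r0 + r1):
        row = []
        for c in range(cols + c0 + c1):
            inside = r0 <= r < r0 + rows and c0 <= c < c0 + cols
            row.append(inp[r - r0][c - c0] if inside else constant)
        out.append(row)
    return out
```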
--- ## pad_reflect
`pad_reflect[output_layout: Layout, input_layout: Layout, dtype: DType, paddings_type: DType](output: LayoutTensor[dtype, output_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, input_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LegacyUnsafePointer[Scalar[paddings_type]])` Fill `output` with values from `input`, and edges padded with reflected values from the unpadded region. Example:

```mojo
var input = [[1, 2], [3, 4]]
var paddings = [2, 2, 1, 0]
```

Yields:

```mojo
output = [[2, 1, 2], [4, 3, 4], [2, 1, 2], [4, 3, 4], [2, 1, 2], [4, 3, 4]]
```

**Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer. * ​paddings (`LegacyUnsafePointer`): Ordered (before, after) padding sizes for each axis.
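Reflection here mirrors indices about the boundary without repeating the edge element. A plain-Python sketch (not the Mojo kernel) that reproduces the documented example:

```python
def reflect_index(i, n):
    # Map an out-of-range index onto its mirror inside [0, n), without
    # repeating the edge element (mirror pattern has period 2 * (n - 1)).
    if n == 1:
        return 0
    period = 2 * (n - 1)
    i = abs(i) % period
    return period - i if i >= n else i

def pad_reflect_2d(inp, paddings):
    # paddings = [r0, r1, c0, c1]: flat (before, after) pairs per axis.
    r0, r1, c0, c1 = paddings
    rows, cols = len(inp), len(inp[0])
    return [
        [inp[reflect_index(r - r0, rows)][reflect_index(c - c0, cols)]
         for c in range(cols + c0 + c1)]
        for r in range(rows + r0 + r1)
    ]
```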
--- ## pad_repeat
`pad_repeat[output_layout: Layout, input_layout: Layout, dtype: DType, paddings_type: DType](output: LayoutTensor[dtype, output_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LegacyUnsafePointer[Scalar[paddings_type]])` Fill `output` with values from `input`, and edges padded with boundary values from the unpadded region. Example:

```mojo
var input = [[1, 2], [3, 4]]
var paddings = [2, 2, 1, 0]
```

Yields:

```mojo
output = [[1, 1, 2], [1, 1, 2], [1, 1, 2], [3, 3, 4], [3, 3, 4], [3, 3, 4]]
```

**Parameters:** * ​output\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the output buffer. * ​input\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): Layout of the input buffer. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType of the input/output buffer. * ​paddings\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType of the paddings buffer. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer. * ​paddings (`LegacyUnsafePointer`): Ordered (before, after) padding sizes for each axis.
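Repeat padding clamps out-of-range indices to the nearest boundary ("edge" padding). A plain-Python sketch (not the Mojo kernel) that reproduces the documented example:

```python
def pad_repeat_2d(inp, paddings):
    # paddings = [r0, r1, c0, c1]: flat (before, after) pairs per axis.
    # Out-of-range indices are clamped to the nearest valid index.
    r0, r1, c0, c1 = paddings
    rows, cols = len(inp), len(inp[0])

    def clamp(i, n):
        return min(max(i, 0), n - 1)

    return [
        [inp[clamp(r - r0, rows)][clamp(c - c0, cols)]
         for c in range(cols + c0 + c1)]
        for r in range(rows + r0 + r1)
    ]
```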
--- ## pad_shape
`pad_shape[input_type: DType, paddings_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings_buf: LayoutTensor[paddings_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `pad` operation, and assert the inputs are compatible. **Parameters:** * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​paddings\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the padding tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The tensor to pad. * ​paddings\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The paddings tensor, of shape (input\_rank, 2). **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
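The output-shape rule itself is simple arithmetic: each axis grows by its (before, after) padding pair. A plain-Python sketch (the flat-list padding layout mirrors the paddings tensor of shape (rank, 2)):

```python
def pad_output_shape(input_shape, paddings):
    # paddings is flat [before_0, after_0, before_1, after_1, ...],
    # i.e. a (rank, 2) paddings tensor read row by row.
    assert len(paddings) == 2 * len(input_shape)
    return tuple(d + paddings[2 * i] + paddings[2 * i + 1]
                 for i, d in enumerate(input_shape))
```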
--- ## get_padding_output_shape
`get_padding_output_shape[rank: Int](input_shape: IndexList[rank], paddings: LayoutTensor[DType.index, Layout(IntTuple((2 * rank))), origin]) -> IndexList[rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## get_row_offset
`get_row_offset[dtype: DType, tensor_layout: Layout](input_tensor: LayoutTensor[dtype, tensor_layout, MutAnyOrigin], output_tensor: LayoutTensor[dtype, tensor_layout, MutAnyOrigin], row_length: Int, row: Int) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## pad_gpu
## Functions * [​`get_padding_output_shape`](./get_padding_output_shape): * [​`get_row_offset`](./get_row_offset): * [​`pad_constant`](./pad_constant): Fill `output` with values from `input`, and edges padded with `constant` based on `paddings`. * [​`padded_copy_kernel`](./padded_copy_kernel): * [​`scalar_copy_row`](./scalar_copy_row): * [​`vector_copy_row`](./vector_copy_row):
--- ## pad_constant (Pad_gpu)
`pad_constant[rank: Int, dtype: DType, padding_type: DType](output: LegacyUnsafePointer[Scalar[dtype]], output_shape: IndexList[rank], input: LegacyUnsafePointer[Scalar[dtype]], input_shape: IndexList[rank], paddings: LegacyUnsafePointer[Scalar[padding_type]], constant: Scalar[dtype], ctx: DeviceContext)` Fill `output` with values from `input`, and edges padded with `constant` based on `paddings`. Example:

```mojo
var input_shape = (X, Y, Z)
var paddings = [x0, x1, y0, y1, z0, z1]

out[x, y, z] = input[x - x0, y - y0, z - z0]
    if x ∈ [x0, x0 + X] && y ∈ [y0, y0 + Y] && z ∈ [z0, z0 + Z]
    else constant
```

**Args:** * ​output (`LegacyUnsafePointer`): The output buffer. * ​output\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The output shape. * ​input (`LegacyUnsafePointer`): The input buffer. * ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The input shape. * ​paddings (`LegacyUnsafePointer`): Ordered (before, after) padding sizes for each axis. * ​constant ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The constant to pad output with. * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context for the participating GPU.
--- ## padded_copy_kernel
`padded_copy_kernel[dtype: DType, tensor_layout: Layout, simd_width: Int](input_tensor: LayoutTensor[dtype, tensor_layout, MutAnyOrigin], output_tensor: LayoutTensor[dtype, tensor_layout, MutAnyOrigin], rows_per_sm: Int, total_rows: Int, row_length: Int, scaled_row_length: Int)`
--- ## scalar_copy_row
`scalar_copy_row[dtype: DType](input_ptr: LegacyUnsafePointer[Scalar[dtype]], output_ptr: LegacyUnsafePointer[Scalar[dtype]], row_length: Int, threads_per_row: Int)`
--- ## vector_copy_row
`vector_copy_row[dtype: DType, simd_width: Int](input_ptr: LegacyUnsafePointer[Scalar[dtype]], output_ptr: LegacyUnsafePointer[Scalar[dtype]], scaled_row_length: Int, row_length: Int, threads_per_row: Int)`
--- ## PoolMethod
`@register_passable(trivial)` `struct PoolMethod` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AVG` `comptime AVG = PoolMethod(1)` ### `MAX` `comptime MAX = PoolMethod(0)` ## Methods ### `__eq__` `__eq__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `__ne__` `__ne__(self, rhs: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## avg_pool
`avg_pool[dtype: DType, int_type: DType, count_boundary: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ceil_mode: Bool = False, ctx_ptr: DeviceContextPtr = DeviceContextPtr())`
--- ## avg_pool_cpu
`avg_pool_cpu[dtype: DType, int_type: DType, rank: Int = 4, count_boundary: Bool = False](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ceil_mode: Bool = False)` Computes the average pool. **Parameters:** * ​count\_boundary ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to count the boundary in the average computation. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Batched image input to the pool2d operator. * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Filter size on height and width dimensions with assumed tuple def (filter\_h, filter\_w). * ​strides ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Strides on height and width dimensions with assumed tuple def (stride\_h, stride\_w).
* ​dilations ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Dilations on height and width dimensions with assumed tuple def (dilation\_h, dilation\_w). * ​paddings ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Paddings on height and width dimensions with assumed tuple def (pad\_h\_before, pad\_h\_after, pad\_w\_before, pad\_w\_after). * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Pre-allocated output tensor space. * ​ceil\_mode ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Ceiling mode defines the output shape and implicit padding.
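A minimal Python sketch of these semantics for a single spatial plane (dilation fixed at 1 for brevity; `avg_pool_2d` is an illustrative helper, not part of this API):

```python
# Pure-Python sketch of 2-D average pooling over one spatial plane.
# count_boundary=True divides by the full window size even when the
# window overlaps padding; False divides only by the in-bounds count.
def avg_pool_2d(plane, filter_hw, strides, paddings, count_boundary=False):
    fh, fw = filter_hw
    sh, sw = strides
    ph0, ph1, pw0, pw1 = paddings
    H, W = len(plane), len(plane[0])
    out_h = (H + ph0 + ph1 - fh) // sh + 1
    out_w = (W + pw0 + pw1 - fw) // sw + 1
    out = [[0.0] * out_w for _ in range(out_h)]
    for oy in range(out_h):
        for ox in range(out_w):
            total, count = 0.0, 0
            for ky in range(fh):
                for kx in range(fw):
                    iy = oy * sh + ky - ph0
                    ix = ox * sw + kx - pw0
                    if 0 <= iy < H and 0 <= ix < W:
                        total += plane[iy][ix]
                        count += 1
            denom = fh * fw if count_boundary else max(count, 1)
            out[oy][ox] = total / denom
    return out
```

With a 2x2 window fully inside a 2x2 input, both settings agree; they differ only where padding enters the window.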
--- ## avg_pool_gpu
`avg_pool_gpu[dtype: DType, int_type: DType, count_boundary: Bool = False](ctx: DeviceContext, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ceil_mode: Bool = False)` Computes the average pool on GPU. **Parameters:** * ​count\_boundary ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to count the boundary in the average computation. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The DeviceContext to use for GPU execution. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On device) Batched image input to the pool2d operator. * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Filter size on height and width dimensions with assumed tuple def (filter\_h, filter\_w). 
* ​strides ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Strides on height and width dimensions with assumed tuple def (stride\_h, stride\_w). * ​dilations ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Dilations on height and width dimensions with assumed tuple def (dilation\_h, dilation\_w). * ​paddings ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Paddings on height and width dimensions with assumed tuple def (pad\_h\_before, pad\_h\_after, pad\_w\_before, pad\_w\_after). * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On device) Pre-allocated output tensor space. * ​ceil\_mode ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Ceiling mode defines the output shape and implicit padding.
--- ## pool
## Structs * [​`PoolMethod`](./PoolMethod): ## Functions * [​`avg_pool`](./avg_pool): * [​`avg_pool_cpu`](./avg_pool_cpu): Computes the average pool. * [​`avg_pool_gpu`](./avg_pool_gpu): Computes the average pool on GPU. * [​`max_pool`](./max_pool): * [​`max_pool_cpu`](./max_pool_cpu): Computes max pooling on CPU. * [​`max_pool_gpu`](./max_pool_gpu): Computes max pooling on GPU. * [​`pool_shape`](./pool_shape): * [​`pool_shape_ceil`](./pool_shape_ceil): * [​`pool_shape_impl`](./pool_shape_impl): Compute the output shape of a pooling operation, and assert the inputs are compatible. Works for 2D pool operations only in the NHWC format.
--- ## max_pool
`max_pool[dtype: DType, int_type: DType, target: StringSlice[StaticConstantOrigin] = "cpu"](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ceil_mode: Bool = False, ctx_ptr: DeviceContextPtr = DeviceContextPtr())`
--- ## max_pool_cpu
`max_pool_cpu[dtype: DType, int_type: DType](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ceil_mode: Bool = False)` Computes max pooling on CPU. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Batched image input to the pool2d operator. * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Filter size on height and width dimensions with assumed tuple def (filter\_h, filter\_w). * ​strides ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Strides on height and width dimensions with assumed tuple def (stride\_h, stride\_w). * ​dilations ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Dilations on height and width dimensions with assumed tuple def (dilation\_h, dilation\_w). 
* ​paddings ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Paddings on height and width dimensions with assumed tuple def (pad\_h\_before, pad\_h\_after, pad\_w\_before, pad\_w\_after). * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Pre-allocated output tensor space. * ​ceil\_mode ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Ceiling mode defines the output shape and implicit padding.
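A minimal Python sketch of max pooling for a single spatial plane (dilation fixed at 1 for brevity; `max_pool_2d` is an illustrative helper, not part of this API):

```python
import math

# Pure-Python sketch of 2-D max pooling over one spatial plane.
# Padded (out-of-bounds) positions act as -inf, so they never win the max.
def max_pool_2d(plane, filter_hw, strides, paddings):
    fh, fw = filter_hw
    sh, sw = strides
    ph0, ph1, pw0, pw1 = paddings
    H, W = len(plane), len(plane[0])
    out_h = (H + ph0 + ph1 - fh) // sh + 1
    out_w = (W + pw0 + pw1 - fw) // sw + 1
    out = [[-math.inf] * out_w for _ in range(out_h)]
    for oy in range(out_h):
        for ox in range(out_w):
            for ky in range(fh):
                for kx in range(fw):
                    iy = oy * sh + ky - ph0
                    ix = ox * sw + kx - pw0
                    if 0 <= iy < H and 0 <= ix < W:
                        out[oy][ox] = max(out[oy][ox], plane[iy][ix])
    return out
```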
--- ## max_pool_gpu
`max_pool_gpu[dtype: DType, int_type: DType](ctx: DeviceContext, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings: LayoutTensor[int_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ceil_mode: Bool = False)` Computes max pooling on GPU. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The DeviceContext to use for GPU execution. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On device) Batched image input to the pool2d operator. * ​filter ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Filter size on height and width dimensions with assumed tuple def (filter\_h, filter\_w). 
* ​strides ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Strides on height and width dimensions with assumed tuple def (stride\_h, stride\_w). * ​dilations ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Dilations on height and width dimensions with assumed tuple def (dilation\_h, dilation\_w). * ​paddings ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On host) Paddings on height and width dimensions with assumed tuple def (pad\_h\_before, pad\_h\_after, pad\_w\_before, pad\_w\_after). * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): (On device) Pre-allocated output tensor space. * ​ceil\_mode ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Ceiling mode defines the output shape and implicit padding.
--- ## pool_shape
`pool_shape[input_type: DType, filter_type: DType, strides_type: DType, dilations_type: DType, paddings_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter_buf: LayoutTensor[filter_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides_buf: LayoutTensor[strides_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations_buf: LayoutTensor[dilations_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings_buf: LayoutTensor[paddings_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## pool_shape_ceil
`pool_shape_ceil[input_type: DType, filter_type: DType, strides_type: DType, dilations_type: DType, paddings_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter_buf: LayoutTensor[filter_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides_buf: LayoutTensor[strides_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations_buf: LayoutTensor[dilations_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings_buf: LayoutTensor[paddings_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## pool_shape_impl
`pool_shape_impl[input_type: DType, filter_type: DType, strides_type: DType, dilations_type: DType, paddings_type: DType, single_thread_blocking_override: Bool, ceil_mode: Bool](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], filter_buf: LayoutTensor[filter_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], strides_buf: LayoutTensor[strides_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], dilations_buf: LayoutTensor[dilations_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], paddings_buf: LayoutTensor[paddings_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a pooling operation, and assert the inputs are compatible. Works for 2D pool operations only in the NHWC format. **Parameters:** * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​filter\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the filter tensor. * ​strides\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the strides tensor. 
* ​dilations\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the dilations tensor. * ​paddings\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the paddings tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​ceil\_mode ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Define rounding mode for shape calculation. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​filter\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The filter size buffer. * ​strides\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The strides size buffer. * ​dilations\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The dilations size buffer. * ​paddings\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The paddings size buffer. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
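For one spatial dimension, the floor/ceil shape rule that distinguishes `pool_shape` from `pool_shape_ceil` can be sketched in Python. The helper below is hypothetical and follows the standard pooling shape formula; the Mojo implementation additionally validates rank and the NHWC layout:

```python
import math

# Sketch of the output-shape rule for one spatial dimension.
# ceil_mode=False floors the division (pool_shape); True ceils it
# (pool_shape_ceil), which can add one extra, implicitly padded window.
def pool_out_dim(in_dim, filter_size, stride, dilation,
                 pad_before, pad_after, ceil_mode=False):
    effective_filter = (filter_size - 1) * dilation + 1
    numerator = in_dim + pad_before + pad_after - effective_filter
    if ceil_mode:
        return math.ceil(numerator / stride) + 1
    return numerator // stride + 1
```

With `in_dim=5`, `filter=2`, `stride=2`, the window positions 0 and 2 fit exactly, and ceil mode adds a third partial window.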
--- ## rand_normal
## Functions * [​`random_normal`](./random_normal): Call `output_fn` with values generated from a normal distribution with the specified mean and standard deviation.
--- ## random_normal
`random_normal[dtype: DType, rank: Int, //, output_fn: fn[width: Int, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None, target: StringSlice[StaticConstantOrigin]](shape: IndexList[rank], mean: Float32, stddev: Float32, seed_value: UInt64, ctx: DeviceContextPtr)` Call `output_fn` with values generated from a normal distribution with the specified mean and standard deviation. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type to generate. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the underlying buffer. * ​output\_fn (`fn[width: Int, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None`): The function which stores the generated values. * ​target (`StringSlice`): The target to run on. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the output being stored into by output\_fn. * ​mean ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): The mean of the normal distribution. * ​stddev ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): The standard deviation of the normal distribution. * ​seed\_value ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): Seed value used to initialize the random number generator. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The device context.
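The callback-based storage contract (generated values are handed to `output_fn` together with the index to store at, rather than written into a buffer directly) can be sketched in plain Python. Everything below, including the `fill_with` helper, is illustrative and not part of the Mojo API:

```python
# Hypothetical sketch of the output_fn contract: a generator produces one
# value per element of `shape`, and output_fn stores it at the given index.
def fill_with(shape, gen, output_fn):
    total = 1
    for d in shape:
        total *= d
    for flat in range(total):
        # unravel the flat counter into a row-major N-D index
        idx, rem = [], flat
        for d in reversed(shape):
            idx.append(rem % d)
            rem //= d
        output_fn(tuple(reversed(idx)), gen())

# usage: collect values keyed by index, with a deterministic "generator"
# standing in for the seeded normal sampler
store = {}
counter = iter(range(100))
fill_with((2, 3), lambda: next(counter),
          lambda idx, val: store.__setitem__(idx, val))
```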
--- ## rand_uniform
## Functions * [​`random_uniform`](./random_uniform): Call `output_fn` with values generated from a uniform distribution on \[lower\_bound, upper\_bound] for floating-point types or \[lower\_bound, upper\_bound) for integer types.
--- ## random_uniform
`random_uniform[dtype: DType, rank: Int, //, output_fn: fn[width: Int, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None, target: StringSlice[StaticConstantOrigin]](shape: IndexList[rank], lower_bound: Scalar[dtype], upper_bound: Scalar[dtype], seed_value: UInt64, ctx: DeviceContextPtr)` Call `output_fn` with values generated from a uniform distribution on \[lower\_bound, upper\_bound] for floating-point types or \[lower\_bound, upper\_bound) for integer types. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type to generate. * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the underlying buffer. * ​output\_fn (`fn[width: Int, _rank: Int](idx: IndexList[_rank], val: SIMD[dtype, width]) capturing -> None`): The function which stores the generated values. * ​target (`StringSlice`): The target to run on. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the output being stored into by output\_fn. * ​lower\_bound ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The lower bound on the uniform range. * ​upper\_bound ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The upper bound on the uniform range. * ​seed\_value ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): Seed value used to initialize the random number generator. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The device context.
--- ## randn
## Functions * [​`random_normal`](./random_normal): Fill `output` with values generated from Normal(mean, variance) distribution.
--- ## random_normal (Randn)
`random_normal[dtype: DType, mean: Float64, variance: Float64](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Fill `output` with values generated from Normal(mean, variance) distribution. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer.
--- ## repeat_interleave
## Functions * [​`repeat_interleave`](./repeat_interleave): Fill `output` by repeating values from `input` along `axis` based on the values in `repeats` buffer. * [​`repeat_interleave_shape`](./repeat_interleave_shape):
--- ## repeat_interleave (Repeat_interleave)
`repeat_interleave[dtype: DType, type_repeats: DType](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], repeats: LayoutTensor[type_repeats, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int, output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Fill `output` by repeating values from `input` along `axis` based on the values in `repeats` buffer. This is intended to implement the same functionality as torch.repeat\_interleave. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer. * ​repeats ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The number of repetitions for each element in `input`. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to repeat values. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer.
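Along a single axis the behaviour reduces to a simple expansion, sketched here in Python for axis 0 of a nested list (`repeat_interleave_axis0` is an illustrative helper, not part of this API):

```python
# Pure-Python sketch of repeat_interleave along axis 0: each slice of
# `rows` appears repeats[i] times in the output, in order. A repeat
# count of 0 drops the slice entirely.
def repeat_interleave_axis0(rows, repeats):
    out = []
    for row, n in zip(rows, repeats):
        out.extend([row] * n)
    return out
```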
--- ## repeat_interleave_shape
`repeat_interleave_shape[type_repeats: DType](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], repeats: LayoutTensor[type_repeats, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## reshape
## Functions * [​`layout_tensor_reshape`](./layout_tensor_reshape): * [​`reshape`](./reshape): * [​`reshape_shape`](./reshape_shape):
--- ## layout_tensor_reshape
`layout_tensor_reshape[output_rank: Int, dtype: DType, single_thread_blocking_override: Bool](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], new_shape: IndexList[output_rank]) -> LayoutTensor[dtype, Layout.row_major[output_rank](), origin, address_space=address_space]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## reshape (Reshape)
`reshape[dtype: DType, //, output_rank: Int, single_thread_blocking_override: Bool = True](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], new_shape: IndexList[output_rank]) -> LayoutTensor[dtype, Layout.row_major[output_rank](), origin, address_space=address_space]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## reshape_shape
`reshape_shape[output_rank: Int, input_type: DType, target_shape_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], target_shape_buf: LayoutTensor[target_shape_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[output_rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## CoordinateTransformationMode
`struct CoordinateTransformationMode` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AlignCorners` `comptime AlignCorners = CoordinateTransformationMode(1)` ### `Asymmetric` `comptime Asymmetric = CoordinateTransformationMode(2)` ### `HalfPixel` `comptime HalfPixel = CoordinateTransformationMode(0)` ### `HalfPixel1D` `comptime HalfPixel1D = CoordinateTransformationMode(3)` ## Methods ### `__init__` `__init__(out self, value: Int)` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## InterpolationMode
`struct InterpolationMode` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Linear` `comptime Linear = InterpolationMode(0)` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## Interpolator
`@register_passable(trivial)` `struct Interpolator[mode: InterpolationMode]` ## Fields * ​cubic\_coeff (`Float32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(cubic_coeff: Float32) -> Self` `__init__() -> Self` ### `filter_length` `static filter_length() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `filter` `filter(self, x: Float32) -> Float32` **Returns:** [`Float32`](/mojo/stdlib/builtin/simd/#float32)
--- ## RoundMode
`struct RoundMode` ## Fields * ​value (`Int`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Ceil` `comptime Ceil = RoundMode(3)` ### `Floor` `comptime Floor = RoundMode(2)` ### `HalfDown` `comptime HalfDown = RoundMode(0)` ### `HalfUp` `comptime HalfUp = RoundMode(1)` ## Methods ### `__init__` `__init__(out self, value: Int)` ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool)
--- ## coord_transform
`coord_transform[mode: CoordinateTransformationMode](out_coord: Int, in_dim: Int, out_dim: Int, scale: Float32) -> Float32` **Returns:** [`Float32`](/mojo/stdlib/builtin/simd/#float32)
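The modes enumerated by `CoordinateTransformationMode` map an output coordinate back into input space. Below is a Python sketch of the arithmetic these names conventionally denote (following ONNX Resize conventions, assuming `scale = out_dim / in_dim`; the Mojo kernel may differ in detail):

```python
# Illustrative sketch of coordinate transformation modes for resize.
# Returns the (possibly fractional) input coordinate for out_coord.
def coord_transform(mode, out_coord, in_dim, out_dim, scale):
    if mode == "half_pixel":
        # align pixel centers: centers sit at integer + 0.5
        return (out_coord + 0.5) / scale - 0.5
    if mode == "align_corners":
        # corner pixels of input and output map onto each other
        if out_dim == 1:
            return 0.0
        return out_coord * (in_dim - 1) / (out_dim - 1)
    if mode == "asymmetric":
        # plain scaling of the top-left-anchored grid
        return out_coord / scale
    raise ValueError(mode)
```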
--- ## resize
## Structs * [​`CoordinateTransformationMode`](./CoordinateTransformationMode): * [​`InterpolationMode`](./InterpolationMode): * [​`Interpolator`](./Interpolator): * [​`RoundMode`](./RoundMode): ## Functions * [​`coord_transform`](./coord_transform): * [​`interpolate_point_1d`](./interpolate_point_1d): * [​`linear_filter`](./linear_filter): This is a tent filter. * [​`resize_linear`](./resize_linear): Resizes input to output shape using linear interpolation. * [​`resize_nearest_neighbor`](./resize_nearest_neighbor):
--- ## interpolate_point_1d
`interpolate_point_1d[in_layout: Layout, //, coordinate_transformation_mode: CoordinateTransformationMode, antialias: Bool, dtype: DType, interpolation_mode: InterpolationMode](interpolator: Interpolator[interpolation_mode], dim: Int, out_coords: IndexList[in_layout.rank()], scale: Float32, input: LayoutTensor[dtype, in_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## linear_filter
`linear_filter(x: Float32) -> Float32` This is a tent filter: f(x) = 1 + x for -1 <= x < 0; f(x) = 1 - x for 0 <= x < 1; f(x) = 0 for |x| >= 1. **Returns:** [`Float32`](/mojo/stdlib/builtin/simd/#float32)
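Written out in Python, the tent filter is simply `1 - |x|` clipped at zero:

```python
# Tent (triangle) filter: peak of 1 at x = 0, falling linearly to 0
# at |x| = 1, and 0 everywhere beyond. Symmetric: f(-x) == f(x).
def linear_filter(x):
    x = abs(x)
    return 1.0 - x if x < 1.0 else 0.0
```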
--- ## resize_linear
`resize_linear[coordinate_transformation_mode: CoordinateTransformationMode, antialias: Bool, dtype: DType](input: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Resizes input to output shape using linear interpolation. Does not use anti-aliasing filter for downsampling (coming soon). **Parameters:** * ​coordinate\_transformation\_mode ([`CoordinateTransformationMode`](/mojo/kernels/nn/resize/CoordinateTransformationMode)): How to map a coordinate in output to a coordinate in input. * ​antialias ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether or not to use an antialiasing linear/cubic filter, which when downsampling, uses more points to avoid aliasing artifacts. Effectively stretches the filter by a factor of 1 / scale. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of input and output. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input to be resized. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output containing the resized input.
--- ## resize_nearest_neighbor
`resize_nearest_neighbor[coordinate_transformation_mode: CoordinateTransformationMode, round_mode: RoundMode, dtype: DType](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## Weighted2DPoint
`@register_passable(trivial)` `struct Weighted2DPoint[dtype: DType]` Utility class to wrap 2-d point coordinates and floating point weight for bilinear interpolation. ## Fields * ​y (`Int`): * ​x (`Int`): * ​w (`Scalar[dtype]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(y: Int, x: Int, weight: Scalar[dtype]) -> Self`
--- ## roi_align
## Structs * [​`Weighted2DPoint`](./Weighted2DPoint): Utility class to wrap 2-d point coordinates and floating point weight for bilinear interpolation. ## Functions * [​`roi_align_nhwc`](./roi_align_nhwc): Compute ROIAlign on a batch of ROIs of shape \[M, 5], where the first element is the batch index, followed by region box coordinates (y0, x0), (y1, x1). For inputs in NHWC format. The output shape is \[M, output\_height, output\_width, C].
--- ## roi_align_nhwc
`roi_align_nhwc[dtype: DType, output_layout: Layout, input_layout: Layout, roi_layout: Layout, //, aligned: Bool, mode: StringSlice[StaticConstantOrigin] = "AVG"](output: LayoutTensor[dtype, output_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], rois: LayoutTensor[dtype, roi_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_height: Int, output_width: Int, in_spatial_scale: Scalar[dtype], in_sampling_ratio: Scalar[dtype])` Compute ROIAlign on a batch of ROIs of shape \[M, 5], where the first element is the batch index, followed by region box coordinates (y0, x0), (y1, x1). For inputs in NHWC format. The output shape is \[M, output\_height, output\_width, C]. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​output\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The output layout. * ​input\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The input layout. * ​roi\_layout ([`Layout`](/mojo/kernels/layout/layout/Layout)): The layout of the regions of interest (ROI). * ​aligned ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If not true, offset the ROIs by 0.5. * ​mode (`StringSlice`): The pooling mode: "AVG" for average and "MAX" for max pooling.
--- ## apply_rope
`apply_rope[dtype: DType, freq_dtype: DType, x_layout: Layout, rank: Int, width: Int, //, *, interleaved: Bool, alignment: Int, output_fn: fn[width: Int, alignment: Int](idx: IndexList[rank], val: SIMD[dtype, width]) capturing -> None](x: LayoutTensor[dtype, x_layout, MutAnyOrigin], idx: IndexList[rank], freq_val: SIMD[freq_dtype, width])`
--- ## get_identity_rope_coeff (Rope)
`get_identity_rope_coeff[width: Int, dtype: DType]() -> SIMD[dtype, width]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## get_safetensors_idx (Rope)
`get_safetensors_idx(head_dim_idx: Int, head_size: Int) -> Tuple[Int, Int]` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)
--- ## rope
## Functions * [​`apply_rope`](./apply_rope): * [​`get_identity_rope_coeff`](./get_identity_rope_coeff): * [​`get_safetensors_idx`](./get_safetensors_idx): * [​`rope_ragged`](./rope_ragged):
--- ## rope_ragged
`rope_ragged[dtype: DType, x_layout: Layout, freq_dtype: DType, input_row_offsets_layout: Layout, start_pos_layout: Layout, freqs_cis_layout: Layout, *, interleaved: Bool, target: StringSlice[StaticConstantOrigin], output_fn: fn[width: Int, alignment: Int](idx: IndexList[3], val: SIMD[dtype, width]) capturing -> None, mrope_section: Optional[IntTuple] = None](x: LayoutTensor[dtype, x_layout, MutAnyOrigin], input_row_offsets: LayoutTensor[DType.uint32, input_row_offsets_layout, MutAnyOrigin], start_pos: LayoutTensor[DType.uint32, start_pos_layout, MutAnyOrigin], freqs_cis: LayoutTensor[freq_dtype, freqs_cis_layout, MutAnyOrigin], context: Optional[DeviceContext], position_ids: OptionalReg[LayoutTensor[DType.uint32, Layout.row_major(-1, -1), MutAnyOrigin]] = None)`
--- ## apply_penalties_to_logits
`apply_penalties_to_logits[logit_type: DType, penalty_type: DType, //, target: StringSlice[StaticConstantOrigin]](logits: LayoutTensor[logit_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], compressed_frequency_data: LayoutTensor[DType.int32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], frequency_offsets: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], frequency_penalty: LayoutTensor[penalty_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], presence_penalty: LayoutTensor[penalty_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], repetition_penalty: LayoutTensor[penalty_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Apply penalties to the logits based on the frequency of the tokens in the batch. The frequency data is stored in CSR format, where frequency\_offsets gives the starting index of each sequence in the compressed\_frequency\_data array. The compressed\_frequency\_data array is a 2D array, where: * compressed\_frequency\_data\[i, 0] is the token id * compressed\_frequency\_data\[i, 1] is the frequency of the token in the sequence
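A NumPy sketch of the penalty application described above. The repetition/frequency/presence arithmetic follows the conventional definitions (an assumption; the kernel's exact formulas are not shown here), and scalar penalties stand in for the per-sequence penalty tensors:

```python
import numpy as np

def apply_penalties(logits, freq_data, freq_offsets,
                    frequency_penalty, presence_penalty, repetition_penalty):
    """Apply penalties per sequence using CSR-style (token, count) rows."""
    for seq in range(len(freq_offsets) - 1):
        start, end = freq_offsets[seq], freq_offsets[seq + 1]
        for row in range(start, end):
            tok, count = freq_data[row]
            if tok < 0 or count <= 0:       # skip padding slots
                continue
            logit = logits[seq, tok]
            # Repetition penalty divides positive logits, multiplies negative ones.
            logit = logit / repetition_penalty if logit > 0 else logit * repetition_penalty
            # Frequency and presence penalties are subtractive.
            logit -= frequency_penalty * count + presence_penalty
            logits[seq, tok] = logit
    return logits
```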
--- ## sampling
## Functions * [​`apply_penalties_to_logits`](./apply_penalties_to_logits): Apply penalties to the logits based on the frequency of the tokens in the batch. * [​`update_frequency_data`](./update_frequency_data): Update the frequency data for the given new tokens. * [​`update_frequency_data_kernel`](./update_frequency_data_kernel): GPU kernel to update token frequency data in CSR format.
--- ## update_frequency_data
`update_frequency_data[token_type: DType, //, target: StringSlice[StaticConstantOrigin]](compressed_frequency_data: LayoutTensor[DType.int32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], frequency_offsets: LayoutTensor[DType.uint32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], new_tokens: LayoutTensor[token_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr)` Update the frequency data for the given new tokens. The frequency data is stored in a CSR format. This kernel expects there will be enough padding for each sequence to store the new tokens.
--- ## update_frequency_data_kernel
`update_frequency_data_kernel[token_type: DType, block_size: Int, freq_data_layout: Layout, freq_offsets_layout: Layout, new_tokens_layout: Layout](compressed_frequency_data: LayoutTensor[DType.int32, freq_data_layout, MutAnyOrigin], frequency_offsets: LayoutTensor[DType.uint32, freq_offsets_layout, MutAnyOrigin], new_tokens: LayoutTensor[token_type, new_tokens_layout, MutAnyOrigin])` GPU kernel to update token frequency data in CSR format. Searches for new tokens in existing frequency data and either increments their count or adds them to the first available padding slot.
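The search-or-claim-padding update can be sketched in NumPy as follows, assuming padding slots are marked with token id -1 and follow all real entries in each sequence's segment (both assumptions; the document does not spell out the padding convention):

```python
import numpy as np

def update_frequency_data(freq_data, freq_offsets, new_tokens):
    """Increment an existing token's count or claim the first padding slot."""
    for seq, tok in enumerate(new_tokens):
        start, end = freq_offsets[seq], freq_offsets[seq + 1]
        for row in range(start, end):
            if freq_data[row, 0] == tok:    # token already tracked
                freq_data[row, 1] += 1
                break
            if freq_data[row, 0] == -1:     # first padding slot: claim it
                freq_data[row] = (tok, 1)
                break
    return freq_data
```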
--- ## get_sliding_window_out_dim
`get_sliding_window_out_dim[ceil_mode: Bool = False](in_dim: Int, ft_dim: Int, dilation: Int, stride: Int, pad: Int) -> Int` Return output dimension for a sliding window operation along some dimension. **Parameters:** * ​ceil\_mode ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Define rounding mode for shape calculation. **Args:** * ​in\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the input dimension. * ​ft\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the corresponding filter dimension. * ​dilation ([`Int`](/mojo/stdlib/builtin/int/Int)): The dilation for the sliding window operation. * ​stride ([`Int`](/mojo/stdlib/builtin/int/Int)): The stride for the sliding window operation. * ​pad ([`Int`](/mojo/stdlib/builtin/int/Int)): The total padding for the sliding window operation. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the output dimension.
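The returned size follows the standard sliding-window formula; a small Python sketch (with `pad` as the total padding across both sides, per the signature above):

```python
import math

def get_sliding_window_out_dim(in_dim, ft_dim, dilation, stride, pad,
                               ceil_mode=False):
    """out = floor_or_ceil((in + pad - dilated_filter) / stride) + 1."""
    effective_filter = dilation * (ft_dim - 1) + 1
    span = in_dim + pad - effective_filter
    if ceil_mode:
        return math.ceil(span / stride) + 1
    return span // stride + 1

# A 3-wide filter over a 7-wide input, stride 2, no padding: 3 windows.
print(get_sliding_window_out_dim(7, 3, 1, 2, 0))  # → 3
```

With `stride=3` the floor and ceil modes diverge: floor gives 2 windows, ceil gives 3 (the last window hangs off the edge).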
--- ## shapes
## Functions * [​`get_sliding_window_out_dim`](./get_sliding_window_out_dim): Return output dimension for a sliding window operation along some dimension.
--- ## copy_to_slice
`copy_to_slice[dtype: DType, start_type: DType, end_type: DType, step_type: DType, target: StringSlice[StaticConstantOrigin] = "cpu"](buffer: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], in_slice: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], start: LayoutTensor[start_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], end: LayoutTensor[end_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], step: LayoutTensor[step_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], context: DeviceContextPtr = DeviceContextPtr())`
--- ## slice
## Functions * [​`copy_to_slice`](./copy_to_slice): * [​`slice_as_copy`](./slice_as_copy): * [​`slice_as_view`](./slice_as_view): * [​`slice_dim_as_view`](./slice_dim_as_view): * [​`slice_shape`](./slice_shape): * [​`sliced_add`](./sliced_add): Adds tensors a and b element-wise for rows < lora\_end\_idx, otherwise copies a.
--- ## slice_as_copy
`slice_as_copy[dtype: DType, index_type: DType](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], start: LayoutTensor[index_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], end: LayoutTensor[index_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], step: LayoutTensor[index_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## slice_as_view
`slice_as_view[dtype: DType, start_type: DType, end_type: DType, step_type: DType](tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], starts: LayoutTensor[start_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ends: LayoutTensor[end_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], steps: LayoutTensor[step_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> LayoutTensor[dtype, Layout.row_major[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank](), origin, address_space=address_space]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## slice_dim_as_view
`slice_dim_as_view[dtype: DType, dim: Int](tensor: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], start: Int, end: Int, step: Int) -> LayoutTensor[dtype, Layout.row_major[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank](), origin, address_space=address_space]` **Returns:** [`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)
--- ## slice_shape
`slice_shape[input_type: DType, start_type: DType, stop_type: DType, step_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], start_buf: LayoutTensor[start_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], stop_buf: LayoutTensor[stop_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], step_buf: LayoutTensor[step_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList)
--- ## sliced_add
`sliced_add[dtype: DType, //, target: StringSlice[StaticConstantOrigin]](c: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], lora_end_idx: LayoutTensor[DType.int64, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: Optional[DeviceContext])` Adds tensors a and b element-wise for rows < lora\_end\_idx, otherwise copies a. This is used for LoRA where only some sequences have LoRA applied. For rows in \[0, lora\_end\_idx): c = a + b For rows in \[lora\_end\_idx, batch\_seq\_len): c = a **Args:** * ​c ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor. * ​a ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): First input tensor. * ​b ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Second input tensor. * ​lora\_end\_idx ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Scalar tensor with end index of LoRA token portion (rows to apply add). * ​ctx ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Device context for GPU operations.
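The row-wise behavior is easy to state in NumPy (a sketch with a plain integer in place of the scalar `lora_end_idx` tensor):

```python
import numpy as np

def sliced_add(a, b, lora_end_idx):
    """Rows below lora_end_idx get a + b; the remaining rows copy a."""
    c = a.copy()
    c[:lora_end_idx] += b[:lora_end_idx]
    return c
```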
--- ## identity
`identity(x: SIMD[dtype, size]) -> SIMD[dtype, size]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## softmax
## Functions * [​`identity`](./identity): * [​`logsoftmax`](./logsoftmax): * [​`mul`](./mul): * [​`reciprocal`](./reciprocal): * [​`reduce_add_simd`](./reduce_add_simd): This function adds val to either the scalar value or the vector value depending on the step\_simd\_width. This is useful when the simd\_width varies between iterations, as in vectorize. * [​`softmax`](./softmax): * [​`softmax_2_pass`](./softmax_2_pass): Performs an unbatched softmax on an input tensor using the two-pass online algorithm. * [​`softmax_3_pass`](./softmax_3_pass): Performs an unbatched softmax on an input tensor using the three-pass algorithm. * [​`softmax_kernel`](./softmax_kernel): * [​`sub`](./sub):
--- ## logsoftmax
`logsoftmax[dtype: DType, simd_width: Int, rank: Int, input_fn: fn[_simd_width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[dtype, _simd_width], target: StringSlice[StaticConstantOrigin] = "cpu"](shape: IndexList[rank], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int, context: DeviceContextPtr = DeviceContextPtr())` `logsoftmax[dtype: DType, simd_width: Int, rank: Int, target: StringSlice[StaticConstantOrigin] = "cpu"](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int, context: DeviceContextPtr = DeviceContextPtr())`
--- ## mul (Softmax)
`mul(x: SIMD[dtype, size], y: SIMD[dtype, size]) -> SIMD[dtype, size]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## reciprocal
`reciprocal(x: SIMD[dtype, size]) -> SIMD[dtype, size]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## reduce_add_simd
`reduce_add_simd[simd_width: Int, step_simd_width: Int, dtype: DType](mut scalar: Scalar[dtype], mut vector: SIMD[dtype, simd_width], val: SIMD[dtype, step_simd_width])` This function adds val to either the scalar value or the vector value, depending on the step\_simd\_width. This is useful when the simd\_width varies between iterations, as in vectorize.
--- ## softmax (Softmax)
`softmax[dtype: DType, simd_width: Int, rank: Int](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int)` `softmax[dtype: DType, simd_width: Int, rank: Int, input_fn: fn[_simd_width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[dtype, _simd_width], target: StringSlice[StaticConstantOrigin] = "cpu", logsoftmax: Bool = False](shape: IndexList[rank], output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int, context: DeviceContextPtr = DeviceContextPtr())`
--- ## softmax_2_pass
`softmax_2_pass[simd_width: Int, dtype: DType](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Performs an unbatched softmax on an input tensor using the two-pass online algorithm. The unbatched two-pass online softmax is described in "Online normalizer calculation for softmax" and "A full-stack search technique for domain optimized deep learning accelerators", and is defined as: ``` procedure SoftmaxUnbatched(Input) runningMax = -∞ runningSum = 0 STAGE 1: for i = 0 to N do newMax = max(runningMax, Input[i]) runningSum = runningSum*exp(runningMax-newMax) + exp(Input[i]-newMax) runningMax = newMax end for STAGE 2: for i = 0 to N do Output[i] = exp(Input[i] - runningMax) / runningSum end for ``` **Parameters:** * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The simd\_width to use in vectorization. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output buffers. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer in which to store the softmax values. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input buffer used to compute the softmax.
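The pseudocode above translates directly to NumPy; this sketch mirrors the stage-1 trick of rescaling the running sum whenever the running max changes:

```python
import numpy as np

def softmax_2_pass(x):
    """Two-pass online softmax: fused max/sum pass, then a normalize pass."""
    running_max = -np.inf
    running_sum = 0.0
    for v in x:  # stage 1: running max with rescaled running sum
        new_max = max(running_max, v)
        running_sum = running_sum * np.exp(running_max - new_max) + np.exp(v - new_max)
        running_max = new_max
    # stage 2: normalize
    return np.exp(np.asarray(x) - running_max) / running_sum
```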
--- ## softmax_3_pass
`softmax_3_pass[simd_width: Int, dtype: DType, origins: OriginSet, input_fn_1d: fn[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width], logsoftmax: Bool = False](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Performs an unbatched softmax on an input tensor using the three-pass algorithm. The unbatched three-pass softmax is defined as: ``` procedure SoftmaxUnbatched(Input) maxVal = -∞ denom = 0 STEP 1: find the max value in each batch for i = 0 to N do maxVal = max(maxVal, Input[b, i]) end for STEP 2: compute the exponential for each batch for i = 0 to N do Output[b, i] = exp(Input[b, i] - maxVal) denom += Output[b, i] end for STEP 3: normalize each batch for i = 0 to N do Output[b, i] /= denom end for ``` **Parameters:** * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The simd\_width to use in vectorization. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output buffers. * ​origins ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The OriginSet of captured arguments by the input\_fn\_1d. * ​input\_fn\_1d (`fn[_simd_width: Int](Int) capturing -> SIMD[dtype, _simd_width]`): The elementwise input lambda. * ​logsoftmax ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Enable to apply elementwise log() to outputs after softmax. **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output buffer in which to store the softmax values.
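A NumPy sketch of the three passes, including the `logsoftmax` variant (which subtracts log(denom) instead of dividing):

```python
import numpy as np

def softmax_3_pass(x, logsoftmax=False):
    """Three-pass softmax: find max, exponentiate and accumulate, normalize."""
    x = np.asarray(x, dtype=np.float64)
    max_val = x.max()               # step 1: max
    out = np.exp(x - max_val)       # step 2: shifted exponentials
    denom = out.sum()
    if logsoftmax:                  # step 3 (log variant)
        return (x - max_val) - np.log(denom)
    return out / denom              # step 3: normalize
```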
--- ## softmax_kernel
`softmax_kernel[BLOCK_SIZE: Int, input_fn: fn[_dtype: DType, _simd_width: Int, _rank: Int](IndexList[_rank]) capturing -> SIMD[_dtype, _simd_width], dtype: DType, layout: Layout, sink_type: DType, rank: Int, accum_type: DType = get_accum_type[dtype](), *, sink: Bool = False, logsoftmax: Bool = False](shape: IndexList[rank], output: LayoutTensor[dtype, layout, MutAnyOrigin], sink_weights: LayoutTensor[sink_type, Layout.row_major(-1), MutAnyOrigin])`
--- ## sub
`sub(x: SIMD[dtype, size], y: SIMD[dtype, size]) -> SIMD[dtype, size]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## spatial_merge
## Functions * [​`spatial_merge`](./spatial_merge): * [​`spatial_merge_kernel`](./spatial_merge_kernel): Spatial merge kernel.
--- ## spatial_merge (Spatial_merge)
`spatial_merge[dtype: DType](output: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], grid_thw: LayoutTensor[DType.int64, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], hidden_size: Int, merge_size: Int, ctx: DeviceContext)`
--- ## spatial_merge_kernel
`spatial_merge_kernel[dtype: DType, input_layout: Layout, output_layout: Layout, grid_thw_layout: Layout](output: LayoutTensor[dtype, output_layout, MutAnyOrigin], input: LayoutTensor[dtype, input_layout, MutAnyOrigin], grid_thw: LayoutTensor[DType.int64, grid_thw_layout, MutAnyOrigin], batch_size: Int, hidden_size: Int, merge_size: Int)` Spatial merge kernel. Grid: 1D over all output patches (one block per output patch). Threads: loop over channels (hidden\_size x merge\_size^2). **Args:** * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input tensor. * ​grid\_thw ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Grid dimensions tensor (B, 3) containing \[t, h, w] for each item. * ​batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of items in batch. * ​hidden\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Hidden dimension size. * ​merge\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Size of spatial merge blocks.
--- ## split
## Functions * [​`split`](./split):
--- ## split (Split)
`split[dtype: DType, num_outputs: Int, target: StringSlice[StaticConstantOrigin], trace_description: StringSlice[StaticConstantOrigin], outputs_origin: MutOrigin, outputs_layout: Layout](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], axis: Int, outputs: StaticTuple[LayoutTensor[dtype, outputs_layout, outputs_origin], num_outputs], ctx: DeviceContext)`
--- ## tile
## Functions * [​`tile`](./tile): Implements the `Tile` operator from the ONNX spec. This behaves like Numpy tile, but without broadcast. * [​`tile_shape`](./tile_shape): Compute the output shape of a `tile` operation, and assert the inputs are compatible.
--- ## tile (Tile)
`tile[dtype: DType, type_repeats: DType](input: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], repeats: LayoutTensor[type_repeats, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])` Implements the `Tile` operator from the ONNX spec. This behaves like Numpy tile, but without broadcast. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input and output tensors. * ​type\_repeats ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the repeats tensor. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. Currently <= 4 dimensions are supported. * ​repeats ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): One-dimensional tensor that specifies the number of repeated copies along each of the input's dimensions. Length equals input tensor rank. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor. Has the same dimensions and type as input.
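Because the operator matches NumPy's tile when `repeats` has one entry per input dimension, its semantics can be shown directly:

```python
import numpy as np

# ONNX `Tile` behaves like numpy.tile with a repeats entry per input dim
# (no broadcasting of the repeats vector).
x = np.array([[1, 2],
              [3, 4]])
tiled = np.tile(x, (2, 3))   # repeat 2x along rows, 3x along columns
print(tiled.shape)           # → (4, 6)
```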
--- ## tile_shape
`tile_shape[input_type: DType, repeats_type: DType, single_thread_blocking_override: Bool](input_buf: LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], repeats_buf: LayoutTensor[repeats_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]) -> IndexList[LayoutTensor[input_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a `tile` operation, and assert the inputs are compatible. **Parameters:** * ​input\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the input tensor. * ​repeats\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the repeats tensor. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. **Args:** * ​input\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​repeats\_buf ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The repeats tensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
--- ## TopK_2
`@register_passable(trivial)` `struct TopK_2[T: DType, largest: Bool = True]` ## Fields * ​p (`Int`): * ​u (`Scalar[T]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` ### `insert` `insert(mut self, elem: Scalar[T], elem_id: Int)`
--- ## apply_gumbel_noise_kernel
`apply_gumbel_noise_kernel[dtype: DType, input_layout: Layout, num_sms: Int, num_threads: Int](output: LayoutTensor[dtype, input_layout, MutAnyOrigin], input: LayoutTensor[dtype, input_layout, MutAnyOrigin], temperature: LegacyUnsafePointer[Float32], seed: LegacyUnsafePointer[UInt64])`
--- ## fused_token_sampling_cpu
`fused_token_sampling_cpu[dtype: DType, out_idx_type: DType](max_k: Int, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_idxs: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: OptionalReg[LayoutTensor[DType.int64, Layout.row_major(-1), MutAnyOrigin]] = None, temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, top_p: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutAnyOrigin]] = None)` Generalized implementation of the Top K algorithm with sampling. Returns the sampled index from the innermost dimension of the input tensor for each row/subvolume. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the input buffer. * ​out\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the output indices. **Args:** * ​max\_k ([`Int`](/mojo/stdlib/builtin/int/Int)): The largest number of top elements to keep. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor (any shape). * ​out\_idxs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output indices, with shape \[input\_shape\[:-1]] + \[1]. * ​k ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional device buffer of top elements to keep for each batch element. * ​temperature ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The temperature-based scaling.
* ​top\_p ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Restrict sampling to the smallest set of tokens whose cumulative probability exceeds this threshold. * ​seed ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The seed to use for the random number generator.
--- ## fused_token_sampling_gpu
`fused_token_sampling_gpu[dtype: DType, out_idx_type: DType, //](ctx: DeviceContext, max_k: Int, min_top_p: Float32, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_idxs: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], block_size: OptionalReg[Int] = None, num_blocks_per_input: OptionalReg[Int] = None, k: OptionalReg[LayoutTensor[DType.int64, Layout.row_major(-1), MutAnyOrigin]] = None, temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, top_p: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutAnyOrigin]] = None)` Top K algorithm with fused sampling. Returns the sampled indices from the Top-K of the innermost dimension of the input tensor for each row/subvolume.
--- ## gumbel_sampling_gpu
`gumbel_sampling_gpu[dtype: DType, out_idx_type: DType, input_layout: Layout, //](ctx: DeviceContext, input: LayoutTensor[dtype, input_layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_idxs: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutAnyOrigin]] = None)` Gumbel sampling using the Gumbel-max trick for categorical distributions. Applies Gumbel(0,1) noise to input logits, then selects the argmax. This is mathematically equivalent to sampling from softmax(logits/temperature) but avoids expensive softmax computation. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context for GPU operations. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input logits tensor \[batch, vocab\_size]. * ​out\_idxs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output tensor for sampled indices \[batch, 1]. * ​temperature ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional per-token temperature scaling \[batch]. * ​seed ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional per-token random seeds \[batch] for reproducibility.
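The Gumbel-max trick named above can be sketched in plain Python (an illustration of the math only, not the Mojo kernel; `gumbel_max_sample` is a hypothetical helper name):

```python
import math
import random

def gumbel_max_sample(logits, temperature=1.0, rng=random):
    """Sample an index distributed as softmax(logits / temperature)
    by adding Gumbel(0, 1) noise to each scaled logit and taking argmax."""
    best_idx, best_val = -1, float("-inf")
    for i, logit in enumerate(logits):
        u = rng.random()
        # Gumbel(0, 1) noise via inverse CDF: -log(-log(U)), U ~ Uniform(0, 1).
        # Small epsilons guard against log(0).
        g = -math.log(-math.log(u + 1e-20) + 1e-20)
        val = logit / temperature + g
        if val > best_val:
            best_idx, best_val = i, val
    return best_idx
```

Because the argmax is invariant to the (unknown) normalizing constant, this samples from the categorical distribution without ever computing the softmax denominator, which is the saving the kernel exploits.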
--- ## topk
## Structs * [​`TopK_2`](./TopK_2): ## Functions * [​`apply_gumbel_noise_kernel`](./apply_gumbel_noise_kernel): * [​`fused_token_sampling_cpu`](./fused_token_sampling_cpu): Generalized implementation of the Top K algorithm with sampling. Returns the sampled index from the innermost dimension of the input tensor for each row/subvolume. * [​`fused_token_sampling_gpu`](./fused_token_sampling_gpu): Top K algorithm with fused sampling. Returns the sampled indices from the Top-K of the innermost dimension of the input tensor for each row/subvolume. * [​`gumbel_sampling_gpu`](./gumbel_sampling_gpu): Gumbel sampling using the Gumbel-max trick for categorical distributions. * [​`top_k`](./top_k): Implementation of the Top K algorithm. Returns the top or bottom K elements and their index along a specified axis. * [​`top_k_shape_impl`](./top_k_shape_impl): Compute the output shape of a top/bottom k operation. * [​`topk_gpu`](./topk_gpu): Generalized implementation of the Top K algorithm with/without sampling. Returns the sampled index from the innermost dimension of the input tensor for each row/subvolume or the top K values and indices across the tensor.
--- ## top_k
`top_k[dtype: DType, out_idx_type: DType, //, largest: Bool = True, target: StringSlice[StaticConstantOrigin] = "cpu"](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_k: Int, axis: Int, out_vals: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_idxs: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sorted: Bool, ctx: DeviceContextPtr, k: OptionalReg[LayoutTensor[DType.int64, Layout.row_major(-1), MutAnyOrigin]] = None)` Implementation of the Top K algorithm. Returns the top or bottom K elements and their index along a specified axis. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the input buffer. * ​out\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data dtype of the output indices (default is DType.int64). * ​largest ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to find the maximum (top k) or minimum value (bottom k). * ​target (`StringSlice`): The target to run on. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​max\_k ([`Int`](/mojo/stdlib/builtin/int/Int)): The largest number of top elements. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis along which to operate. * ​out\_vals ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output values. * ​out\_idxs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output indices. 
* ​sorted ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Indicates if the top/bottom K elements are in (stable) sorted order. * ​ctx ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The device call context. * ​k ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional per-batch-element k values.
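For intuition, the top/bottom-K selection along one axis corresponds to something like this Python sketch (illustrative only; the Mojo function works on `LayoutTensor`s along an arbitrary axis, and `top_k_1d` is a hypothetical name):

```python
import heapq

def top_k_1d(values, k, largest=True):
    """Return (vals, idxs) of the k largest (or smallest) elements.
    Results come back in sorted order; ties keep their original order,
    since nlargest/nsmallest behave like a stable sort."""
    pick = heapq.nlargest if largest else heapq.nsmallest
    pairs = pick(k, enumerate(values), key=lambda p: p[1])
    idxs = [i for i, _ in pairs]
    vals = [v for _, v in pairs]
    return vals, idxs
```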
--- ## top_k_shape_impl
`top_k_shape_impl[dtype: DType, single_thread_blocking_override: Bool](input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_k: Int, axis: Int) -> IndexList[LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank]` Compute the output shape of a top/bottom k operation. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the input buffer. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If this function can block. **Args:** * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor. * ​max\_k ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum K value. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis value in a tensor. **Returns:** [`IndexList`](/mojo/stdlib/utils/index_/IndexList): The output shape.
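The shape rule itself is simple: the output matches the input shape except along `axis`, whose extent becomes `max_k`. A minimal Python sketch of that rule (hypothetical helper name):

```python
def top_k_shape(input_shape, max_k, axis):
    """Output shape of a top/bottom-k op: input shape with the
    axis dimension replaced by max_k."""
    if not 0 <= axis < len(input_shape):
        raise ValueError("axis out of range")
    out = list(input_shape)
    out[axis] = max_k
    return tuple(out)
```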
--- ## topk_gpu
`topk_gpu[dtype: DType, out_idx_type: DType, //, sampling: Bool = True, largest: Bool = True, _force_old_impl: Bool = False](ctx: DeviceContext, max_k: Int, input: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_vals: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_idxs: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], block_size: OptionalReg[Int] = None, num_blocks_per_input: OptionalReg[Int] = None, k: OptionalReg[LayoutTensor[DType.int64, Layout.row_major(-1), MutAnyOrigin]] = None, temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, top_p: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutAnyOrigin]] = None)` Generalized implementation of the Top K algorithm with/without sampling. Returns the sampled index from the innermost dimension of the input tensor for each row/subvolume or the top K values and indices across the tensor. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data dtype of the input tensor. * ​out\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data dtype of the output indices (default is DType.int). * ​sampling ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether to return token samples from topK dist (default is True). * ​largest ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether to find the maximum or minimum value. 
* ​\_force\_old\_impl ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether to force use the old implementation. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): DeviceContext The context for GPU execution. * ​max\_k ([`Int`](/mojo/stdlib/builtin/int/Int)): Int Largest number of top elements to keep for each batch element. * ​input ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): NDBuffer\[dtype, rank] Input tensor as a device NDBuffer. * ​out\_vals ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): NDBuffer\[dtype, rank] Output buffer on device for the K largest values. * ​out\_idxs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): NDBuffer\[DType.int, rank] Output buffer on device for the indices of the K largest values, or sampled token indices. Last dimension is 1 if sampling is True, otherwise K. * ​block\_size ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Int The number of threads per block (default is 256 from TRT and empirical testing). * ​num\_blocks\_per\_input ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): OptionalReg\[Int] Number of blocks per input (default computed from input size and block size). This is the equivalent of "BLOCKS\_PER\_BEAM" in TRT-LLM kernel allowing for much larger batch sizes through packing several elements per thread in the first stage. * ​k ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional NDBuffer\[DType.int64, 1, MutAnyOrigin] Device buffer of top elements to keep for each batch element. * ​temperature ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The temperature based scaling. * ​top\_p ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Only use the tokens whose cumulative probability exceeds this threshold. 
* ​seed ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The seed to use for the random number generator.
--- ## TopKMaskLogitsKernel
`TopKMaskLogitsKernel[block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, logits_layout: Layout, masked_logits_layout: Layout](logits: LayoutTensor[dtype, logits_layout, MutAnyOrigin], masked_logits: LayoutTensor[dtype, masked_logits_layout, MutAnyOrigin], top_k_arr: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_val: Int, d: Int)`
--- ## TopKSamplingFromProbKernel
`TopKSamplingFromProbKernel[block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, probs_layout: Layout, output_layout: Layout, deterministic: Bool](probs: LayoutTensor[dtype, probs_layout, MutAnyOrigin], output: LayoutTensor[out_idx_type, output_layout, MutAnyOrigin], indices: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_arr: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_val: Int, d: Int, rng_seed: UInt64, rng_offset: UInt64)` Kernel for top-k sampling from a probability distribution. This kernel performs top-k sampling by: 1. Using ternary search to find a pivot threshold. 2. Rejecting samples iteratively until the acceptance criterion is met. 3. Sampling an index using uniform random numbers from the random number generator. **Args:** * ​probs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input probability distribution \[batch\_size, d]. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output sampled indices \[batch\_size]. * ​indices (`LegacyUnsafePointer`): Optional row indices for batch indexing \[batch\_size]. * ​top\_k\_arr (`LegacyUnsafePointer`): Optional per-row top\_k values \[batch\_size]. * ​top\_k\_val ([`Int`](/mojo/stdlib/builtin/int/Int)): Default top\_k value if top\_k\_arr is null. * ​d ([`Int`](/mojo/stdlib/builtin/int/Int)): Vocabulary size. * ​rng\_seed ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): Random seed for the random number generator. * ​rng\_offset ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): Random offset for the random number generator.
--- ## TopKSoftmaxSampleKernel
`TopKSoftmaxSampleKernel[block_size: Int, vec_size: Int, dtype: DType, out_idx_type: DType, logits_layout: Layout, sampled_indices_layout: Layout](logits: LayoutTensor[dtype, logits_layout, MutAnyOrigin], sampled_indices: LayoutTensor[out_idx_type, sampled_indices_layout, MutAnyOrigin], top_k_arr: LegacyUnsafePointer[Scalar[out_idx_type]], top_k_val: Int, temperature_val: Float32, temperature: LegacyUnsafePointer[Float32], seed_val: UInt64, seed: LegacyUnsafePointer[UInt64], d: Int)`
--- ## ValueCount
`@register_passable(trivial)` `struct ValueCount[T: DType]` A struct that holds a value and a count, used for block reductions. This is useful for computing both the sum of values and the count of elements that satisfy a condition in a single reduction pass. ## Parameters * ​T ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType of the value field. ## Fields * ​value (`Scalar[T]`): * ​count (`Int32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(value: Scalar[T], count: Int32) -> Self` `__init__() -> Self` ### `__add__` `__add__(self, other: Self) -> Self` ### `__iadd__` `__iadd__(mut self, other: Self)`
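The same value-plus-count pattern in ordinary Python: one pass over the data accumulates both fields at once, which is what makes the single-reduction trick work (a sketch of the idea, not the Mojo struct; `sum_and_count` is a hypothetical name):

```python
def sum_and_count(values, predicate):
    """Single-pass reduction returning (sum of qualifying values, their count)."""
    total, count = 0.0, 0
    for v in values:
        if predicate(v):
            total += v   # plays the role of the 'value' field
            count += 1   # plays the role of the 'count' field
    return total, count
```

In the GPU version, pairs like this are combined associatively (mirroring `__add__` above) so the reduction can run in parallel across threads.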
--- ## device_sampling_from_prob
`device_sampling_from_prob[vec_size: Int, block_size: Int, dtype: DType, deterministic: Bool = False](i: Int, d: Int, low: Float64, u: Float32, prob_vec: SIMD[DType.float32, vec_size], aggregate: Float32, sampled_id_sram: LegacyUnsafePointer[Int, address_space=AddressSpace.SHARED], last_valid_id_sram: LegacyUnsafePointer[Int, address_space=AddressSpace.SHARED]) -> Float32` Device-level sampling from a probability distribution using atomic operations. **Returns:** [`Float32`](/mojo/stdlib/builtin/simd/#float32)
--- ## get_min_max_value
`get_min_max_value[vec_size: Int, block_size: Int, dtype: DType](in_data: LegacyUnsafePointer[Scalar[dtype]], row_idx: Int, d: Int) -> Tuple[Float32, Float32]` Compute the minimum and maximum values from input data using block reduction. **Parameters:** * ​vec\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements each thread processes per iteration (vectorization width). * ​block\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of threads per block. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input data. **Args:** * ​in\_data (`LegacyUnsafePointer`): Pointer to input data buffer. * ​row\_idx ([`Int`](/mojo/stdlib/builtin/int/Int)): Row index for the current block (for 2D data access). * ​d ([`Int`](/mojo/stdlib/builtin/int/Int)): Total number of elements in the row. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Tuple containing \[min\_val, max\_val].
--- ## topk_fi
## Structs * [​`ValueCount`](./ValueCount): A struct that holds a value and a count, used for block reductions. ## Functions * [​`device_sampling_from_prob`](./device_sampling_from_prob): Device-level sampling from probability distribution with atomic operations. * [​`get_min_max_value`](./get_min_max_value): Compute the minimum and maximum values from input data using block reduction. * [​`topk_mask_logits`](./topk_mask_logits): * [​`topk_sampling_from_prob`](./topk_sampling_from_prob): Top-K sampling from probability distribution. * [​`topk_softmax_sample`](./topk_softmax_sample): Samples token indices from top-K logits using softmax probabilities. * [​`TopKMaskLogitsKernel`](./TopKMaskLogitsKernel): * [​`TopKSamplingFromProbKernel`](./TopKSamplingFromProbKernel): Kernel for top-k sampling from probability distribution. * [​`TopKSoftmaxSampleKernel`](./TopKSoftmaxSampleKernel):
--- ## topk_mask_logits
`topk_mask_logits[dtype: DType, out_idx_type: DType, block_size: Int = 1024](ctx: DeviceContext, logits: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], masked_logits: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], top_k_val: Int, top_k_arr: OptionalReg[LayoutTensor[out_idx_type, Layout.row_major(-1), MutAnyOrigin]] = None)`
--- ## topk_sampling_from_prob
`topk_sampling_from_prob[dtype: DType, out_idx_type: DType, block_size: Int = 1024](ctx: DeviceContext, probs: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], top_k_val: Int, deterministic: Bool = False, rng_seed: UInt64 = 0, rng_offset: UInt64 = 0, indices: OptionalReg[LayoutTensor[out_idx_type, Layout.row_major(-1), MutAnyOrigin]] = None, top_k_arr: OptionalReg[LayoutTensor[out_idx_type, Layout.row_major(-1), MutAnyOrigin]] = None)` Top-K sampling from a probability distribution. Performs stochastic sampling from a probability distribution, considering only the top-k most probable tokens. Uses rejection sampling with ternary search to efficiently find appropriate samples. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): Device context for kernel execution. * ​probs ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input probability distribution \[batch\_size, d]. * ​output ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output sampled indices \[batch\_size]. * ​top\_k\_val ([`Int`](/mojo/stdlib/builtin/int/Int)): Default top-k value (number of top tokens to consider). * ​deterministic ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to use deterministic sampling. * ​rng\_seed ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): Random seed for the random number generator. * ​rng\_offset ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): Random offset for the random number generator. * ​indices ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional row indices for batch indexing \[batch\_size]. 
* ​top\_k\_arr ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional per-row top-k values \[batch\_size]. **Raises:** Error: If tensor ranks or shapes are invalid.
--- ## topk_softmax_sample
`topk_softmax_sample[dtype: DType, out_idx_type: DType, block_size: Int = 1024](ctx: DeviceContext, logits: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], sampled_indices: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], top_k_val: Int, temperature_val: Float32 = 1, seed_val: UInt64 = 0, top_k_arr: OptionalReg[LayoutTensor[out_idx_type, Layout.row_major(-1), MutAnyOrigin]] = None, temperature: OptionalReg[LayoutTensor[DType.float32, Layout.row_major(-1), MutAnyOrigin]] = None, seed: OptionalReg[LayoutTensor[DType.uint64, Layout.row_major(-1), MutAnyOrigin]] = None)` Samples token indices from top-K logits using softmax probabilities. This kernel performs single-pass top-K selection and categorical sampling: 1. Finds the k-th largest logit via ternary search. 2. Computes softmax over top-K elements and caches them in shared memory. 3. Samples a single token index from the categorical distribution. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the input logits tensor. * ​out\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the output sampled indices. * ​block\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of threads per block (default is 1024). **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): DeviceContext The context for GPU execution. * ​logits ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Input logits tensor with shape \[batch\_size, vocab\_size]. * ​sampled\_indices ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): Output buffer for sampled token indices with shape \[batch\_size]. 
* ​top\_k\_val ([`Int`](/mojo/stdlib/builtin/int/Int)): Int Default number of top elements to sample from for each batch element. * ​temperature\_val ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): Float32 Temperature for softmax scaling (default is 1.0). * ​seed\_val ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): UInt64 Seed for the random number generator (default is 0). * ​top\_k\_arr ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional per-batch top-K values. If provided, overrides top\_k\_val for each batch element. * ​temperature ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional per-batch temperature values. If provided, overrides temperature\_val for each batch element. * ​seed ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional per-batch seed values. If provided, overrides seed\_val for each batch element.
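The three kernel steps listed above (top-K cut, temperature softmax over the survivors, categorical draw) can be sketched at reference level in Python; this mirrors the pipeline but is not the GPU kernel, and `topk_softmax_sample_ref` is an illustrative name:

```python
import math
import random

def topk_softmax_sample_ref(logits, k, temperature=1.0, rng=random):
    """Keep the k largest logits, softmax them (with the usual max-subtraction
    for numerical stability), then sample one index from that distribution."""
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = order[:k]
    m = max(logits[i] for i in kept)
    # exp((logit - m) / T) is softmax(logits / T) up to a common factor.
    weights = [math.exp((logits[i] - m) / temperature) for i in kept]
    z = sum(weights)
    # Inverse-CDF draw over the unnormalized weights.
    r = rng.random() * z
    acc = 0.0
    for i, w in zip(kept, weights):
        acc += w
        if r <= acc:
            return i
    return kept[-1]
```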
--- ## toppminp
## Functions * [​`merge`](./merge): Merge two sorted subarrays into one sorted array. * [​`merge_sort_recursive`](./merge_sort_recursive): Recursive merge sort implementation. * [​`min_p_sampling`](./min_p_sampling): Naive CPU implementation of Min-P sampling for token selection. This function applies temperature scaling, softmax, a merge sort, and then samples tokens based on the calculated probability threshold (Min-P). * [​`sort_buf_descending`](./sort_buf_descending): Sort each batch separately in descending order using parallel merge sort. * [​`top_p_sampling`](./top_p_sampling): Naive CPU implementation of Top-P sampling for token selection. This function applies temperature scaling, softmax, a merge sort, and then samples tokens based on the cumulative probability mass (Top-P).
--- ## merge
`merge[dtype: DType, out_idx_type: DType](mut buf_keys: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mut buf_ids: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], start: Int, mid: Int, end: Int)` Merge two sorted subarrays into one sorted array.
--- ## merge_sort_recursive
`merge_sort_recursive[dtype: DType, out_idx_type: DType](mut buf_keys: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mut buf_ids: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], start: Int, end: Int)` Recursive merge sort implementation.
--- ## min_p_sampling
`min_p_sampling[dtype: DType, out_idx_type: DType, //, _test_sort: Bool = False](min_ps: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_logits: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_token_ids: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], temperature: Scalar[dtype] = 1)` Naive CPU implementation of Min-P sampling for token selection. This function applies temperature scaling, softmax, a merge sort, and then samples tokens based on the calculated probability threshold (Min-P).
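The Min-P rule keeps tokens whose probability is at least `min_p` times the highest probability. A hedged Python sketch of just the filtering step (the Mojo function also sorts and then samples from the survivors; `min_p_filter` is a hypothetical name):

```python
import math

def min_p_filter(logits, min_p, temperature=1.0):
    """Indices whose softmax probability >= min_p * (max probability)."""
    m = max(logits)
    probs = [math.exp((l - m) / temperature) for l in logits]
    z = sum(probs)
    probs = [p / z for p in probs]
    threshold = min_p * max(probs)
    return [i for i, p in enumerate(probs) if p >= threshold]
```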
--- ## sort_buf_descending
`sort_buf_descending[dtype: DType, out_idx_type: DType](mut buf_keys: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], mut buf_ids: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], vocab_size: Int)` Sort each batch separately in descending order using parallel merge sort.
--- ## top_p_sampling
`top_p_sampling[dtype: DType, out_idx_type: DType, //, _test_sort: Bool = False](top_ps: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_logits: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_token_ids: LayoutTensor[out_idx_type, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], temperature: Scalar[dtype] = 1)` Naive CPU implementation of Top-P sampling for token selection. This function applies temperature scaling, softmax, a merge sort, and then samples tokens based on the cumulative probability mass (Top-P).
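Top-P (nucleus) filtering keeps the smallest descending-probability prefix whose cumulative mass reaches the threshold. An illustrative Python sketch of that step (not the Mojo implementation; `top_p_filter` is a hypothetical name):

```python
import math

def top_p_filter(logits, top_p, temperature=1.0):
    """Indices of the smallest set of highest-probability tokens whose
    cumulative softmax probability reaches top_p."""
    m = max(logits)
    exps = [math.exp((l - m) / temperature) for l in logits]
    z = sum(exps)
    order = sorted(range(len(logits)), key=lambda i: exps[i], reverse=True)
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += exps[i] / z
        if cum >= top_p:
            break
    return kept
```

Sampling then proceeds over the kept set with renormalized probabilities, which is why the CPU path sorts first: the cumulative sum only makes sense in descending order.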
--- ## DoubleBuffer
`struct DoubleBuffer[dtype: DType]` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` `__init__(out self, current: LegacyUnsafePointer[Scalar[dtype]], alternate: LegacyUnsafePointer[Scalar[dtype]], size: Int)` ### `current` `current(self, ctx: DeviceContext) -> DeviceBuffer[dtype]` **Returns:** [`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer) ### `alternate` `alternate(self, ctx: DeviceContext) -> DeviceBuffer[dtype]` **Returns:** [`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer) ### `swap` `swap(mut self)`
--- ## toppminp_gpu
## `comptime` values ### `DEBUG_FILE` `comptime DEBUG_FILE = False` ### `SEED` `comptime SEED = 42` ## Structs * [​`DoubleBuffer`](./DoubleBuffer): ## Functions * [​`min_p_sampling_gpu`](./min_p_sampling_gpu): GPU implementation of Min-P sampling for token selection. This function applies temperature scaling, softmax, a radix sort, and then samples tokens based on the calculated probability threshold (Min-P). * [​`normalize`](./normalize): * [​`normalize_u32`](./normalize_u32): * [​`radix_sort_pairs_kernel`](./radix_sort_pairs_kernel): Radix pair sort kernel for (default) descending order. * [​`run_radix_sort_pairs_gpu`](./run_radix_sort_pairs_gpu): * [​`top_p_sampling_gpu`](./top_p_sampling_gpu): GPU implementation of Top-P sampling for token selection. This function applies temperature scaling, softmax, a radix sort, and then samples tokens based on the cumulative probability mass (Top-P). * [​`topk_wrapper`](./topk_wrapper): Copy of `Kernels/mojo/nn/topk.mojo:_topk_stage1` with the addition of max\_vals and p\_threshold arguments to determine if sorting is needed for top-p/min-p sampling. * [​`topp_minp_sampling_kernel`](./topp_minp_sampling_kernel): Top P-Min P sampling kernel.
--- ## min_p_sampling_gpu
`min_p_sampling_gpu[dtype: DType, out_idx_type: DType, //, _test_sort: Bool = False](ctx: DeviceContext, min_ps: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_logits: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_token_ids: LayoutTensor[out_idx_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], temperature: Scalar[dtype] = 1)` GPU implementation of Min-P sampling for token selection. This function applies temperature scaling, softmax, a radix sort, and then samples tokens based on the calculated probability threshold (Min-P).
--- ## normalize
`normalize(value: BFloat16) -> UInt16` **Returns:** [`UInt16`](/mojo/stdlib/builtin/simd/#uint16) `normalize(value: Int32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) `normalize(value: UInt16) -> UInt16` **Returns:** [`UInt16`](/mojo/stdlib/builtin/simd/#uint16) `normalize(value: Float32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32) `normalize(value: Scalar[dtype]) -> Scalar[_uint_type_of_width[bit_width_of[dtype]()]()]` Normalize the value to the appropriate unsigned integer type. This is needed for radix sort to work correctly. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar)
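For floating-point keys this kind of normalization is typically the classic bit-flip transform: set the sign bit for non-negative floats and invert all bits for negative ones, so that unsigned integer comparison matches float ordering. A hedged Python sketch of the float32 case (the exact Mojo mapping may differ; `float32_radix_key` is an illustrative name):

```python
import struct

def float32_radix_key(value):
    """Map a float32 to a uint32 whose unsigned order matches float order."""
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    if bits & 0x80000000:            # negative: invert everything
        return bits ^ 0xFFFFFFFF
    return bits | 0x80000000         # non-negative: set the sign bit
```

Without such a transform, raw IEEE-754 bit patterns sort negative numbers backwards and above positives, so radix sort on the raw bits would give the wrong order.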
--- ## normalize_u32
`normalize_u32(value: UInt32) -> UInt32` **Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32)
--- ## radix_sort_pairs_kernel
`radix_sort_pairs_kernel[dtype: DType, out_idx_type: DType, current_bit: Int, ascending: Bool = False, BLOCK_SIZE: Int = 256, NUM_BITS_PER_PASS: Int = 4](input_keys_: LegacyUnsafePointer[Scalar[dtype]], output_keys_: LegacyUnsafePointer[Scalar[dtype]], input_key_ids_: LegacyUnsafePointer[Scalar[out_idx_type]], output_key_ids_: LegacyUnsafePointer[Scalar[out_idx_type]], num_keys: Int, skip_sort: LegacyUnsafePointer[Scalar[DType.bool]])` Radix pair sort kernel for (default) descending order. Implementation based on: AMD. Introduction to GPU Radix Sort. GPUOpen, 2017. Available at: . **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - Data type. * ​out\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - Output index type. * ​current\_bit ([`Int`](/mojo/stdlib/builtin/int/Int)): Int - Current bit to start sorting NUM\_BITS\_PER\_PASS bits at. * ​ascending ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether to sort in ascending order. * ​BLOCK\_SIZE ([`Int`](/mojo/stdlib/builtin/int/Int)): Int - Block size. * ​NUM\_BITS\_PER\_PASS ([`Int`](/mojo/stdlib/builtin/int/Int)): Int - Number of bits per pass. **Args:** * ​input\_keys\_ (`LegacyUnsafePointer`): Input tensor values to sort. * ​output\_keys\_ (`LegacyUnsafePointer`): Output tensor values sorted in (default) descending order. * ​input\_key\_ids\_ (`LegacyUnsafePointer`): Input tensor indices. * ​output\_key\_ids\_ (`LegacyUnsafePointer`): Output tensor indices sorted in (default) descending order. * ​num\_keys ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of keys to sort per batch. * ​skip\_sort (`LegacyUnsafePointer`): Whether sorting is skipped for this batch.
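One pass of the pair sort can be sketched on the CPU as a stable bucket pass over `NUM_BITS_PER_PASS` bits (illustrative Python only; the GPU kernel follows the histogram-based GPUOpen approach cited above):

```python
def radix_pass_pairs(keys, ids, current_bit, num_bits=4, ascending=False):
    """One stable LSD radix pass over (key, id) pairs, descending by default."""
    buckets = 1 << num_bits
    digit = lambda k: (k >> current_bit) & (buckets - 1)
    order = range(buckets) if ascending else range(buckets - 1, -1, -1)
    out_keys, out_ids = [], []
    for d in order:                       # counting-sort style: bucket by digit,
        for k, i in zip(keys, ids):       # preserving input order within a bucket
            if digit(k) == d:
                out_keys.append(k)
                out_ids.append(i)
    return out_keys, out_ids
```

Running one pass per `current_bit` step (0, 4, 8, ...) over all key bits yields a fully sorted (key, id) sequence.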
--- ## run_radix_sort_pairs_gpu
`run_radix_sort_pairs_gpu[dtype: DType, out_idx_type: DType, ascending: Bool = False, BLOCK_SIZE: Int = 256, NUM_BITS_PER_PASS: Int = 4](ctx: DeviceContext, mut keys: DoubleBuffer[dtype], mut key_ids: DoubleBuffer[out_idx_type], skip_sort: LegacyUnsafePointer[Scalar[DType.bool]], in_shape: IndexList[size, element_type=element_type])`
--- ## top_p_sampling_gpu
`top_p_sampling_gpu[dtype: DType, out_idx_type: DType, //, _test_sort: Bool = False](ctx: DeviceContext, top_ps: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_logits: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], out_token_ids: LayoutTensor[out_idx_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], temperature: Scalar[dtype] = 1)` GPU implementation of Top-P sampling for token selection. This function applies temperature scaling, softmax, a radix sort, and then samples tokens based on the cumulative probability mass (Top-P).
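The Top-P rule can likewise be sketched on the CPU (illustrative Python with a hypothetical `top_p_sample` helper; the GPU path uses the radix sort above instead of `sorted`):

```python
import math
import random

def top_p_sample(logits, top_p, temperature=1.0, rng=random):
    """CPU sketch of Top-P (nucleus) sampling: softmax, sort descending,
    then sample from the smallest prefix whose cumulative mass reaches top_p."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: -probs[i])
    kept, cum = [], 0.0
    for i in order:                 # smallest prefix covering top_p mass
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    r = rng.random() * cum          # sample proportionally within the prefix
    for i in kept:
        r -= probs[i]
        if r <= 0:
            return i
    return kept[-1]
```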
--- ## topk_wrapper
`topk_wrapper[T: DType, out_idx_type: DType, is_top_p: Bool, largest: Bool = True, _test_sort: Bool = False](K: Int, num_elements: Int, num_blocks_per_input: Int, in_buffer: LegacyUnsafePointer[Scalar[T]], local_topk_vals: LegacyUnsafePointer[Scalar[T]], local_topk_idxs: LegacyUnsafePointer[Scalar[out_idx_type]], p_threshold: LegacyUnsafePointer[Scalar[T]], skip_sort: LegacyUnsafePointer[Scalar[DType.bool]])` Copy of `Kernels/mojo/nn/topk.mojo:_topk_stage1` with the addition of max\_vals and p\_threshold arguments to determine if sorting is needed for top-p/min-p sampling. Arguments: K: Int - Number of top elements to select per block num\_elements: Int - Size of last dimension of input buffer (vocab size) num\_blocks\_per\_input: Int - Number of blocks used to process the input data in\_buffer: UnsafePointer\[Scalar\[T]] - Input buffer containing the elements to process local\_topk\_vals: UnsafePointer\[Scalar\[T]] - Output buffer to store the local top-K values local\_topk\_idxs: UnsafePointer\[Scalar\[out\_idx\_type]] - Output buffer to store the indices of local top-K elements p\_threshold: UnsafePointer\[Scalar\[T]] - Threshold for top-p sampling if is\_top\_p is True; otherwise the min-p coefficient skip\_sort: UnsafePointer\[Scalar\[DType.bool]] - Output buffer to store whether sorting is needed **Parameters:** * ​T ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data type of the elements. * ​out\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data type of the output indices. * ​is\_top\_p ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether this is for top-p sampling or min-p sampling. * ​largest ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether to find the maximum or minimum value. * ​\_test\_sort ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - An internal test flag to avoid skipping the sort during testing.
--- ## topp_minp_sampling_kernel
`topp_minp_sampling_kernel[dtype: DType, out_idx_type: DType, is_top_p: Bool](p_thresholds_: LegacyUnsafePointer[Scalar[dtype]], sorted_probs_: LegacyUnsafePointer[Scalar[dtype]], sorted_ids_: LegacyUnsafePointer[Scalar[out_idx_type]], out_token_ids: LegacyUnsafePointer[Scalar[out_idx_type]], skip_sort: LegacyUnsafePointer[Scalar[DType.bool]], vocab_size: Int)` Top P-Min P sampling kernel. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data type of the scalar values. * ​out\_idx\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType - The data type of the output indices. * ​is\_top\_p ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Bool - Whether to use Top-P (True) or Min-P (False) sampling.
--- ## nvml
Implements wrappers around the NVIDIA Management Library (nvml). ## Modules * [​`nvml`](./nvml/): Implements wrappers around the NVIDIA Management Library (nvml).
--- ## ClockType
`@register_passable(trivial)` `struct ClockType` ## Fields * ​code (`Int32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `GRAPHICS` `comptime GRAPHICS = ClockType(0)` Graphics clock domain. ### `MEM` `comptime MEM = ClockType(2)` Memory clock domain. ### `SM` `comptime SM = ClockType(1)` SM clock domain. ### `VIDEO` `comptime VIDEO = ClockType(3)` Video clock domain. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** `Bool`
--- ## Device
`struct Device` ## Fields * ​idx (`Int`): * ​device (`_DeviceImpl`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self, idx: Int = 0)` ### `get_driver_version` `get_driver_version(self) -> DriverVersion` Returns the NVIDIA driver version. **Returns:** `DriverVersion` ### `max_mem_clock` `max_mem_clock(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `max_graphics_clock` `max_graphics_clock(self) -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int) ### `mem_clocks` `mem_clocks(self) -> List[Int]` **Returns:** [`List`](/mojo/stdlib/collections/list/List) ### `graphics_clocks` `graphics_clocks(self, memory_clock_mhz: Int) -> List[Int]` **Returns:** [`List`](/mojo/stdlib/collections/list/List) ### `set_clock` `set_clock(self, mem_clock: Int, graphics_clock: Int)` ### `gpu_turbo_enabled` `gpu_turbo_enabled(self) -> Bool` Returns True if GPU turbo is enabled. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `set_gpu_turbo` `set_gpu_turbo(self, enabled: Bool = True)` Sets the GPU turbo state. ### `get_persistence_mode` `get_persistence_mode(self) -> Bool` Returns True if GPU persistence mode is enabled. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool) ### `set_persistence_mode` `set_persistence_mode(self, enabled: Bool = True)` Sets the persistence mode. ### `set_max_gpu_clocks` `set_max_gpu_clocks(device)` ### `__str__` `__str__(self) -> String` **Returns:** `String` ### `write_to` `write_to(self, mut writer: T)` ### `__repr__` `__repr__(self) -> String` **Returns:** `String`
--- ## DriverVersion
`struct DriverVersion` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`StringableRaising`](/mojo/stdlib/builtin/str/StringableRaising), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, var value: List[String])` ### `__copyinit__` `__copyinit__(out self, other: Self)` ### `major` `major(self) -> Int` **Returns:** `Int` ### `minor` `minor(self) -> Int` **Returns:** `Int` ### `patch` `patch(self) -> Int` **Returns:** `Int` ### `__str__` `__str__(self) -> String` **Returns:** `String`
--- ## EnableState
`@register_passable(trivial)` `struct EnableState` ## Fields * ​code (`Int32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `DISABLED` `comptime DISABLED = EnableState(0)` Feature disabled. ### `ENABLED` `comptime ENABLED = EnableState(1)` Feature enabled. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** `Bool`
--- ## Result
`@register_passable(trivial)` `struct Result` ## Fields * ​code (`Int32`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `ALREADY_INITIALIZED` `comptime ALREADY_INITIALIZED = Result(5)` Deprecated: Multiple initializations are now allowed through ref counting. ### `ARGUMENT_VERSION_MISMATCH` `comptime ARGUMENT_VERSION_MISMATCH = Result(25)` The provided version is invalid/unsupported. ### `CORRUPTED_INFOROM` `comptime CORRUPTED_INFOROM = Result(14)` The infoROM is corrupted. ### `DEPRECATED` `comptime DEPRECATED = Result(26)` The requested functionality has been deprecated. ### `DRIVER_NOT_LOADED` `comptime DRIVER_NOT_LOADED = Result(9)` NVIDIA driver is not loaded. ### `FREQ_NOT_SUPPORTED` `comptime FREQ_NOT_SUPPORTED = Result(24)` The requested frequency is not supported. ### `FUNCTION_NOT_FOUND` `comptime FUNCTION_NOT_FOUND = Result(13)` Local version of NVML doesn't implement this function. ### `GPU_IS_LOST` `comptime GPU_IS_LOST = Result(15)` The GPU has fallen off the bus or has otherwise become inaccessible. ### `GPU_NOT_FOUND` `comptime GPU_NOT_FOUND = Result(28)` No GPUs were found. ### `IN_USE` `comptime IN_USE = Result(19)` An operation cannot be performed because the GPU is currently in use.
### `INSUFFICIENT_POWER` `comptime INSUFFICIENT_POWER = Result(8)` A device's external power cables are not properly attached. ### `INSUFFICIENT_RESOURCES` `comptime INSUFFICIENT_RESOURCES = Result(23)` Ran out of critical resources, other than memory. ### `INSUFFICIENT_SIZE` `comptime INSUFFICIENT_SIZE = Result(7)` An input argument is not large enough. ### `INVALID_ARGUMENT` `comptime INVALID_ARGUMENT = Result(2)` A supplied argument is invalid. ### `IRQ_ISSUE` `comptime IRQ_ISSUE = Result(11)` NVIDIA Kernel detected an interrupt issue with a GPU. ### `LIB_RM_VERSION_MISMATCH` `comptime LIB_RM_VERSION_MISMATCH = Result(18)` RM detects a driver/library version mismatch. ### `LIBRARY_NOT_FOUND` `comptime LIBRARY_NOT_FOUND = Result(12)` NVML Shared Library couldn't be found or loaded. ### `MEMORY` `comptime MEMORY = Result(20)` Insufficient memory. ### `NO_DATA` `comptime NO_DATA = Result(21)` No data. ### `NO_PERMISSION` `comptime NO_PERMISSION = Result(4)` The current user does not have permission for operation. ### `NOT_FOUND` `comptime NOT_FOUND = Result(6)` A query to find an object was unsuccessful. ### `NOT_READY` `comptime NOT_READY = Result(27)` The system is not ready for the request. ### `NOT_SUPPORTED` `comptime NOT_SUPPORTED = Result(3)` The requested operation is not available on target device. ### `OPERATING_SYSTEM` `comptime OPERATING_SYSTEM = Result(17)` The GPU control device has been blocked by the operating system/cgroups. ### `RESET_REQUIRED` `comptime RESET_REQUIRED = Result(16)` The GPU requires a reset before it can be used again. ### `SUCCESS` `comptime SUCCESS = Result(0)` The operation was successful. ### `TIMEOUT` `comptime TIMEOUT = Result(10)` User provided timeout passed. ### `UNINITIALIZED` `comptime UNINITIALIZED = Result(1)` NVML was not first initialized with `nvmlInit()`. ### `UNKNOWN` `comptime UNKNOWN = Result(999)` An internal driver error occurred. 
### `VGPU_ECC_NOT_SUPPORTED` `comptime VGPU_ECC_NOT_SUPPORTED = Result(22)` The requested vgpu operation is not available on target device, because ECC is enabled. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` **Returns:** `Bool` ### `write_to` `write_to(self, mut writer: T)` ### `__str__` `__str__(self) -> String` **Returns:** `String`
--- ## nvml (Nvml)
Implements wrappers around the NVIDIA Management Library (nvml). ## `comptime` values ### `CUDA_NVML_LIBRARY` `comptime CUDA_NVML_LIBRARY = _Global["CUDA_NVML_LIBRARY", _init_dylib]` ### `CUDA_NVML_LIBRARY_BASE_NAME` `comptime CUDA_NVML_LIBRARY_BASE_NAME = "libnvidia-ml"` ### `CUDA_NVML_LIBRARY_DIR` `comptime CUDA_NVML_LIBRARY_DIR = "/usr/lib/x86_64-linux-gnu"` ### `CUDA_NVML_LIBRARY_EXT` `comptime CUDA_NVML_LIBRARY_EXT = ".so"` ## Structs * [​`ClockType`](./ClockType): * [​`Device`](./Device): * [​`DriverVersion`](./DriverVersion): * [​`EnableState`](./EnableState): * [​`Result`](./Result):
--- ## quantization
This package contains a set of APIs for quantizing tensor data. Quantization is a technique used to reduce the precision of floating-point numbers, which are used in most neural networks. Quantization is a type of lossy compression, which means that some precision is lost, but the resulting tensors take less memory and computations are faster. ## Modules * [​`per_channel_grouped_4bit`](./per_channel_grouped_4bit/): * [​`qmatmul`](./qmatmul/): * [​`qmatmul_gpu`](./qmatmul_gpu/): * [​`qmatmul_k`](./qmatmul_k/):
--- ## Q4sym
`struct Q4sym[group_size: Int, float_dtype: DType = DType.float32]` Q4sym: compresses values of type `float_dtype` to 4-bit unsigned integers which have been dynamically symmetrically quantized with the given scale factor. `group_size` determines the number of elements which share quantization parameters. We store things in a strided fashion. Example: assume `group_size = 8` and we want to store the uint4 numbers A, B, C, D, E, F, G, H, which have associated bits aaaa, bbbb, cccc, and so on. They are packed as: eeeeaaaa|ffffbbbb|ggggcccc|hhhhdddd. To uncompress to floating point, take the decoded uint4 value, subtract the implicit zero-point of 2^3 = 8, and multiply by the scale factor. ## Parameters * ​group\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of encoded numbers stored in this struct. * ​float\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The floating point dtype this struct works with. ## Fields * ​scale (`StaticTuple[UInt8, 2]`): The FP16 scale of the group, stored as individual bytes. * ​bits (`StaticTuple[UInt8, (group_size // 2)]`): The bits of the encoded uint4 numbers. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` Construct a default initialized Q4sym. `__init__(out self, data: SIMD[float_dtype, group_size])` Construct an encoded Q4sym from data. **Args:** * ​data ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The floating point data to encode and store. ### `decode_scale` `decode_scale(mut self) -> Float16` Obtain the scale factor. **Returns:** `Float16`: The decoded scale factor. ### `decode_unsigned` `decode_unsigned(mut self) -> SIMD[DType.uint8, group_size]` Decode the stored uint4 numbers to uint8.
**Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The decoded stored numbers as uint8 numbers. These have an implicit zero-point of 8. ### `decode_signed` `decode_signed(mut self) -> SIMD[DType.int8, group_size]` Decode the stored uint4 numbers to requantized int4 numbers. This is done by simply subtracting an implicit zp of 8 from the unsigned decoding. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The decoded stored numbers as int8 numbers. These have a zero-point of 0\. ### `decode_fully` `decode_fully(mut self) -> SIMD[float_dtype, group_size]` Decode the stored numbers into floating point representation. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The decoded numbers. ### `quantize_and_write_to_tensor` `static quantize_and_write_to_tensor(input_tensor: LayoutTensor[float_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_tensor: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], input_shape: IndexList[LayoutTensor[float_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank])` Encodes the floating point numbers in `input_tensor` along the inner-most dimension and writes the result to output\_tensor. **Args:** * ​input\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor we are encoding. * ​output\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor containing the encoded input. The shape of the output should be the same as the input except along the inner dimension where if the original inner dimension was `d`, the corresponding output dimension should be: ceil(`d` / group\_size) \* size\_of(self). 
* ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the input tensor. ### `dequantize_and_write_to_tensor` `static dequantize_and_write_to_tensor(input_tensor: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_tensor: LayoutTensor[float_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_shape: IndexList[LayoutTensor[float_dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank])` Decodes the quantized numbers in `input_tensor` along the inner-most dimension and writes the floating point result to output\_tensor. **Args:** * ​input\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The input tensor we are decoding. * ​output\_tensor ([`LayoutTensor`](/mojo/kernels/layout/layout_tensor/LayoutTensor)): The output tensor containing the decoded input. * ​output\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the output tensor.
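The strided nibble layout described above can be illustrated in Python (hypothetical helper names; the actual struct also carries an FP16 scale alongside the packed bits):

```python
def pack_q4_strided(q):
    """Pack group_size uint4 values in the strided layout: byte j holds
    element j in its low nibble and element j + group_size//2 in its high
    nibble (e.g. eeeeaaaa|ffffbbbb|ggggcccc|hhhhdddd for group_size 8)."""
    half = len(q) // 2
    return [((q[j + half] & 0xF) << 4) | (q[j] & 0xF) for j in range(half)]

def unpack_q4_strided(bits):
    """Inverse of pack_q4_strided: low nibbles first, then high nibbles."""
    lo = [b & 0xF for b in bits]
    hi = [b >> 4 for b in bits]
    return lo + hi

def dequantize(q_unsigned, scale):
    # Subtract the implicit zero-point of 8, then apply the scale factor.
    return [(u - 8) * scale for u in q_unsigned]
```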
--- ## block_Q4_K
`struct block_Q4_K` ## Fields * ​base\_scale (`Float16`): * ​base\_min (`Float16`): * ​q\_scales\_and\_mins (`InlineArray[UInt8, 12]`): * ​q\_bits (`InlineArray[UInt8, 128]`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `group_count` `comptime group_count = 8` ### `group_size` `comptime group_size = 32`
--- ## block_Q6_K
`struct block_Q6_K` ## Fields * ​q\_bits\_lo (`InlineArray[UInt8, 128]`): * ​q\_bits\_hi (`InlineArray[UInt8, 64]`): * ​q\_scales (`InlineArray[Int8, 16]`): * ​base\_scale (`Float16`): ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `group_count` `comptime group_count = 16` ### `group_size` `comptime group_size = 16`
--- ## block_QK_K
`struct block_QK_K` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `quantized_k` `comptime quantized_k = 256`
--- ## calculate_symmetric_vector
`calculate_symmetric_vector[input_dtype: DType, simd_width: Int, output_bits: Int](data: SIMD[input_dtype, simd_width]) -> Tuple[SIMD[DType.uint8, simd_width], Scalar[input_dtype]]` Symmetrically quantizes the given SIMD vector `data` with input type `input_dtype` and `simd_width` elements, assuming we want the results to fit in an unsigned integer of size `output_bits`. **Parameters:** * ​input\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input tensor. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the SIMD input. * ​output\_bits ([`Int`](/mojo/stdlib/builtin/int/Int)): The bits we want to fit the unsigned integral result in. **Args:** * ​data ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input SIMD we want to quantize. **Returns:** `Tuple`: A vector of the quantized values. The associated scale factor.
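As a rough illustration of what a symmetric quantizer of this shape computes (hypothetical `calculate_symmetric` helper; the exact rounding and scale conventions of `calculate_symmetric_vector` may differ):

```python
def calculate_symmetric(data, output_bits=4):
    """Sketch of symmetric quantization: choose a scale so the largest-
    magnitude value fits the unsigned range around an implicit midpoint
    zero-point (8 for 4-bit), and return (quantized values, scale)."""
    levels = 1 << output_bits          # e.g. 16 representable values for 4-bit
    zero_point = levels // 2           # implicit midpoint, e.g. 8
    amax = max(abs(v) for v in data)
    scale = amax / (zero_point - 1) if amax else 1.0
    quantized = [
        max(0, min(levels - 1, round(v / scale) + zero_point)) for v in data
    ]
    return quantized, scale
```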
--- ## per_channel_grouped_4bit
## Structs * [​`block_Q4_K`](./block_Q4_K): * [​`block_Q6_K`](./block_Q6_K): * [​`block_QK_K`](./block_QK_K): * [​`Q4sym`](./Q4sym): Q4sym: compresses values of type `float_dtype` to 4-bit unsigned integers which have been dynamically symmetrically quantized with the given scale factor. ## Functions * [​`calculate_symmetric_vector`](./calculate_symmetric_vector): Symmetrically quantizes the given SIMD vector `data` with input type `input_dtype` and `simd_width` elements, assuming we want the results to fit in an unsigned integer of size `output_bits`. * [​`q4_k_dequantize_impl`](./q4_k_dequantize_impl): * [​`q6_k_dequantize_impl`](./q6_k_dequantize_impl): * [​`scale_min_k4`](./scale_min_k4):
--- ## q4_k_dequantize_impl
`q4_k_dequantize_impl(input_tensor: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_tensor: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## q6_k_dequantize_impl
`q6_k_dequantize_impl(input_tensor: LayoutTensor[DType.uint8, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_tensor: LayoutTensor[DType.float32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], output_shape: IndexList[LayoutTensor[DType.float32, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment].rank])`
--- ## scale_min_k4
`scale_min_k4(src_ptr: LegacyUnsafePointer[block_Q4_K, mut=mut, origin=origin], g: Int) -> Tuple[Float32, Float32]` **Returns:** `Tuple`
--- ## qmatmul
## `comptime` values ### `K_BATCH_SIZE` `comptime K_BATCH_SIZE = 512` Defines the batch size of K used to pack A and unpack B weights. ## Functions * [​`matmul_qint4`](./matmul_qint4): * [​`matmul_qint4_pack_b`](./matmul_qint4_pack_b):
--- ## matmul_qint4
`matmul_qint4[group_size: Int, b_layout: Layout = Layout.row_major[2](), elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[DType.uint8, b_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## matmul_qint4_pack_b
`matmul_qint4_pack_b[group_size: Int](b: LayoutTensor[DType.uint8, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_rot: LayoutTensor[DType.uint8, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## args_to_tuple
`args_to_tuple[swap: Bool](arg_0: Int, arg_1: Int) -> Tuple[Int, Int]` **Returns:** `Tuple`
--- ## gpu_qint4_repack_GPTQ
`gpu_qint4_repack_GPTQ[group_size: Int, target: StringSlice[StaticConstantOrigin]](b: LayoutTensor[DType.uint8, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[DType.uint8, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], perm_idx: OptionalReg[LayoutTensor[DType.int32, Layout.row_major(-1), MutAnyOrigin]] = None, ctx: DeviceContextPtr = DeviceContextPtr())`
--- ## gpu_qint4_repack_Q4_0
`gpu_qint4_repack_Q4_0[b_shape: DimList, //, target: StringSlice[StaticConstantOrigin]](b: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr = DeviceContextPtr())`
--- ## qmatmul_gpu
## Functions * [​`args_to_tuple`](./args_to_tuple): * [​`gpu_qint4_repack_GPTQ`](./gpu_qint4_repack_GPTQ): * [​`gpu_qint4_repack_Q4_0`](./gpu_qint4_repack_Q4_0): * [​`matmul_gpu_qint4`](./matmul_gpu_qint4): * [​`matmul_gpu_qint4_impl`](./matmul_gpu_qint4_impl): * [​`multistage_gemm_q`](./multistage_gemm_q): * [​`multistage_mma_q`](./multistage_mma_q): * [​`multistage_qgemm_kernel`](./multistage_qgemm_kernel): * [​`pack_Q_tile`](./pack_Q_tile): * [​`q_smem_usage`](./q_smem_usage): * [​`repack_GPTQ_for_sm8x`](./repack_GPTQ_for_sm8x): * [​`repack_Q4_0_for_sm8x`](./repack_Q4_0_for_sm8x): * [​`unpack_4bit_int`](./unpack_4bit_int):
--- ## matmul_gpu_qint4
`matmul_gpu_qint4[c_type: DType, a_type: DType, //, group_size: Int, target: StringSlice[StaticConstantOrigin], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: DeviceContextPtr = DeviceContextPtr())`
--- ## matmul_gpu_qint4_impl
`matmul_gpu_qint4_impl[c_type: DType, a_type: DType, //, group_size: Int, target: StringSlice[StaticConstantOrigin], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], ctx: Optional[DeviceContext])`
--- ## multistage_gemm_q
`multistage_gemm_q[c_type: DType, a_type: DType, b_type: DType, //, *, group_size: Int, pack_factor: Int, config: MatmulConfig[a_type, b_type, c_type, True], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], a: LayoutTensor[a_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[b_type, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], runtime_config: MatmulConfig[a_type, b_type, c_type, True], ctx: DeviceContext)`
--- ## multistage_mma_q
`multistage_mma_q[BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, num_threads: Int, num_pipeline_stages: Int, transpose_b: Bool, group_size: Int, pack_factor: Int, c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, a_smem_layout: Layout, b_type: DType, b_layout: Layout, b_smem_layout: Layout, scales_type: DType, scales_layout: Layout, scales_smem_layout: Layout, /, *, swizzle_a: Bool = True, static_num_iters: Int = -1, prefetch_init: Bool = True, continue_prefetch_b: Bool = False, transpose_b_next: Bool = False, b_next_gmem_layout: Layout = Layout(), b_next_smem_layout: Layout = Layout(), next_op_b_iter_alignment: Int = align_of[b_type]()](c: LayoutTensor[c_type, c_layout, origin, address_space=AddressSpace.LOCAL], a_iter_arg: LayoutTensorIter[dtype, a_layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], b_iter_arg: LayoutTensorIter[b_type, b_layout, origin, address_space=address_space, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], a_smem_iter_arg: LayoutTensorIter[a_type, a_smem_layout, origin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], mut b_smem_iter: LayoutTensorIter[b_type, b_smem_layout, origin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], scales_smem_iter_arg: LayoutTensorIter[scales_type, scales_smem_layout, origin, address_space=AddressSpace.SHARED, alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], scales_iter_arg: LayoutTensorIter[scales_type, scales_layout, origin, address_space=address_space, 
alignment=alignment, circular=circular, axis=axis, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked], num_iters: Int, /, *, num_b_rows: OptionalReg[Int] = None)`
--- ## multistage_qgemm_kernel
`multistage_qgemm_kernel[c_type: DType, c_layout: Layout, a_type: DType, a_layout: Layout, b_packed_type: DType, b_layout: Layout, group_size: Int, pack_factor: Int, transpose_b: Bool, config: MatmulConfig[a_type, b_packed_type, c_type, transpose_b], elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](c: LayoutTensor[c_type, c_layout, MutAnyOrigin], a: LayoutTensor[a_type, a_layout, MutAnyOrigin], b_packed: LayoutTensor[b_packed_type, b_layout, MutAnyOrigin])`
--- ## pack_Q_tile
`pack_Q_tile(input: SIMD[DType.uint8, 16]) -> SIMD[DType.uint32, 4]` **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD)
--- ## q_smem_usage
`q_smem_usage[config: MatmulConfig[a_type, b_type, c_type, transpose_b], group_size: Int]() -> Int` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int)
--- ## repack_GPTQ_for_sm8x
`repack_GPTQ_for_sm8x[in_layout: Layout, out_layout: Layout, scales_type: DType, group_size: Int, has_perm: Bool, *, perm_layout: Layout = Layout()](in_tensor: LayoutTensor[DType.uint8, in_layout, MutAnyOrigin], out_tensor: LayoutTensor[DType.uint8, out_layout, MutAnyOrigin], perm_idx: LayoutTensor[DType.int32, perm_layout, MutAnyOrigin])`
--- ## repack_Q4_0_for_sm8x
`repack_Q4_0_for_sm8x[q_layout: Layout, repack_layout: Layout, scales_type: DType](q_weight: LayoutTensor[DType.uint8, q_layout, MutAnyOrigin], q_packed_weight: LayoutTensor[DType.uint8, repack_layout, MutAnyOrigin])`
--- ## unpack_4bit_int
`unpack_4bit_int(val: SIMD[DType.uint32, size], idx: Int) -> UInt8` **Returns:** `UInt8`
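The signature suggests `unpack_4bit_int` extracts the `idx`-th 4-bit value from a vector of packed 32-bit words. As a rough illustration of the nibble arithmetic only (not the Mojo kernel itself — the little-endian nibble ordering below is an assumption), here is a Python sketch:

```python
def unpack_4bit_int(packed: list[int], idx: int) -> int:
    """Extract the idx-th 4-bit value from a list of packed uint32 words.

    Assumes a little-endian nibble layout (8 nibbles per 32-bit word);
    the actual kernel's ordering may differ.
    """
    word = packed[idx // 8]       # 8 nibbles fit in one uint32
    shift = (idx % 8) * 4         # bit position of the nibble inside the word
    return (word >> shift) & 0xF


# 0x76543210 packs the values 0..7, one per nibble.
print([unpack_4bit_int([0x76543210], i) for i in range(8)])  # → [0, 1, 2, 3, 4, 5, 6, 7]
```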
--- ## qmatmul_k
## Functions * [​`matmul_Q4_K`](./matmul_Q4_K): * [​`matmul_Q4_K_pack_b`](./matmul_Q4_K_pack_b): * [​`matmul_Q6_K`](./matmul_Q6_K): * [​`matmul_Q6_K_pack_b`](./matmul_Q6_K_pack_b):
--- ## matmul_Q4_K
`matmul_Q4_K[elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## matmul_Q4_K_pack_b
`matmul_Q4_K_pack_b(b: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## matmul_Q6_K
`matmul_Q6_K[elementwise_lambda_fn: OptionalReg[fn[dtype: DType, width: Int, *, alignment: Int = 1](IndexList[2], SIMD[dtype, width]) capturing -> None] = None](a: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c: LayoutTensor[DType.float32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## matmul_Q6_K_pack_b
`matmul_Q6_K_pack_b(b: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_packed: LayoutTensor[DType.uint8, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])`
--- ## elementwise
`elementwise[func: fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None, simd_width: Int, *, use_blocking_impl: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu", _trace_description: StringSlice[StaticConstantOrigin] = ""](shape: Int)` Executes `func[width, rank](indices)`, possibly as sub-tasks, for a suitable combination of width and indices so as to cover shape. Returns when all sub-tasks have completed. **Parameters:** * ​func (`fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None`): The body function. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width to use. * ​use\_blocking\_impl ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Do not invoke the function using asynchronous calls. * ​target (`StringSlice`): The target to run on. * ​\_trace\_description (`StringSlice`): Description of the trace. **Args:** * ​shape ([`Int`](/mojo/stdlib/builtin/int/Int)): The shape of the buffer. **Raises:** If the operation fails. `elementwise[rank: Int, //, func: fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None, simd_width: Int, *, use_blocking_impl: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu", _trace_description: StringSlice[StaticConstantOrigin] = ""](shape: IndexList[rank, element_type=element_type])` Executes `func[width, rank](indices)`, possibly as sub-tasks, for a suitable combination of width and indices so as to cover shape. Returns when all sub-tasks have completed. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the buffer. * ​func (`fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None`): The body function. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width to use. * ​use\_blocking\_impl ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Do not invoke the function using asynchronous calls. * ​target (`StringSlice`): The target to run on. 
* ​\_trace\_description (`StringSlice`): Description of the trace. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the buffer. **Raises:** If the operation fails. `elementwise[func: fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None, simd_width: Int, *, use_blocking_impl: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu", _trace_description: StringSlice[StaticConstantOrigin] = ""](shape: Int, context: DeviceContext)` Executes `func[width, rank](indices)`, possibly as sub-tasks, for a suitable combination of width and indices so as to cover shape. Returns when all sub-tasks have completed. **Parameters:** * ​func (`fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None`): The body function. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width to use. * ​use\_blocking\_impl ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Do not invoke the function using asynchronous calls. * ​target (`StringSlice`): The target to run on. * ​\_trace\_description (`StringSlice`): Description of the trace. **Args:** * ​shape ([`Int`](/mojo/stdlib/builtin/int/Int)): The shape of the buffer. * ​context ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The device context to use. **Raises:** If the operation fails. `elementwise[rank: Int, //, func: fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None, simd_width: Int, *, use_blocking_impl: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu", _trace_description: StringSlice[StaticConstantOrigin] = ""](shape: IndexList[rank, element_type=element_type], context: DeviceContext)` Executes `func[width, rank](indices)`, possibly as sub-tasks, for a suitable combination of width and indices so as to cover shape. Returns when all sub-tasks have completed. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the buffer. 
* ​func (`fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None`): The body function. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width to use. * ​use\_blocking\_impl ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Do not invoke the function using asynchronous calls. * ​target (`StringSlice`): The target to run on. * ​\_trace\_description (`StringSlice`): Description of the trace. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the buffer. * ​context ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The device context to use. **Raises:** If the operation fails. `elementwise[rank: Int, //, func: fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None, simd_width: Int, *, use_blocking_impl: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu", _trace_description: StringSlice[StaticConstantOrigin] = ""](shape: IndexList[rank, element_type=element_type], context: DeviceContextPtr)` Executes `func[width, rank](indices)`, possibly as sub-tasks, for a suitable combination of width and indices so as to cover shape. Returns when all sub-tasks have completed. **Parameters:** * ​rank ([`Int`](/mojo/stdlib/builtin/int/Int)): The rank of the buffer. * ​func (`fn[width: Int, rank: Int, alignment: Int = 1](IndexList[rank]) capturing -> None`): The body function. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width to use. * ​use\_blocking\_impl ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Do not invoke the function using asynchronous calls. * ​target (`StringSlice`): The target to run on. * ​\_trace\_description (`StringSlice`): Description of the trace. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the buffer. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The device context to use. **Raises:** If the operation fails.
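The covering contract — invoke `func[width](indices)` with a suitable width until the whole shape is processed — can be sketched in Python. This is a serial, 1-D analogue of the dispatch only; the real implementation also handles multi-dimensional shapes, asynchronous sub-tasks, and device targets:

```python
def elementwise_1d(func, simd_width: int, shape: int) -> None:
    """Cover [0, shape) by calling func(width, index): full-width calls
    for the main body, then width-1 calls for the remainder."""
    main = (shape // simd_width) * simd_width
    for i in range(0, main, simd_width):
        func(simd_width, i)
    for i in range(main, shape):
        func(1, i)


# Double 10 floats with a vector width of 4: two width-4 calls, two width-1 calls.
data = [float(i) for i in range(10)]

def body(width, i):
    for k in range(i, i + width):
        data[k] *= 2.0

elementwise_1d(body, 4, len(data))
print(data)
```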
--- ## functional
Implements higher-order functions. You can import these APIs from the `algorithm` package. For example: ```mojo from algorithm import map ``` ## `comptime` values ### `BinaryTile1DTileUnitFunc` `comptime BinaryTile1DTileUnitFunc = fn[width: Int](Int, Int) capturing -> None` Signature of a tiled function that performs some work with a dynamic tile size and a secondary static tile size. ### `Dynamic1DTileUnitFunc` `comptime Dynamic1DTileUnitFunc = fn(Int, Int) capturing -> None` Signature of a 1D tiled function with dynamic tile size. The function takes a dynamic tile size and an offset argument, i.e. `func(offset: Int, tile_size: Int)`. ### `Dynamic1DTileUnswitchUnitFunc` `comptime Dynamic1DTileUnswitchUnitFunc = fn[sw: Bool](Int, Int, Int) capturing -> None` Signature of a dynamic tiled unswitch unit function. ### `Static1DTileUnitFunc` `comptime Static1DTileUnitFunc = fn[width: Int](Int) capturing -> None` Signature of a 1D tiled function with static tile size. The function takes a static tile size parameter and an offset argument, i.e. `func[tile_size: Int](offset: Int)`. ### `Static1DTileUnitFuncWithFlag` `comptime Static1DTileUnitFuncWithFlag = fn[width: Int, flag: Bool](Int) capturing -> None` Signature of a tiled function with a static tile size, offset, and flag. ### `Static1DTileUnitFuncWithFlags` `comptime Static1DTileUnitFuncWithFlags = fn[width: Int, left_flag: Bool, right_flag: Bool](Int) capturing -> None` Signature of a tiled function with left and right boundary flags. ### `Static1DTileUnswitchUnitFunc` `comptime Static1DTileUnswitchUnitFunc = fn[width: Int, sw: Bool](Int, Int) capturing -> None` Signature of a tiled function with static tile size and unswitch flag. The function takes a static tile size parameter, an unswitch flag, and offset arguments, i.e. `func[tile_size: Int, sw: Bool](offset: Int, upperbound: Int)`. ### `Static2DTileUnitFunc` `comptime Static2DTileUnitFunc = fn[tile_x: Int, tile_y: Int](Int, Int) capturing -> None` Signature of a 2D tiled function with static tile size. 
The function takes static tile size parameters and offset arguments, i.e. `func[tile_size_x: Int, tile_size_y: Int](offset_x: Int, offset_y: Int)`. ### `stencil` `comptime stencil = _stencil_impl_cpu` CPU implementation of stencil computation. ### `stencil_gpu` `comptime stencil_gpu = _stencil_impl_gpu` GPU implementation of stencil computation. ### `SwitchedFunction` `comptime SwitchedFunction = fn[sw: Bool]() raises capturing -> None` Signature of a function that unswitch can take. ### `SwitchedFunction2` `comptime SwitchedFunction2 = fn[sw0: Bool, sw1: Bool]() capturing -> None` Signature for unswitch supporting 2 predicates. ## Functions * [​`elementwise`](/mojo/stdlib/algorithm/functional/elementwise): Executes `func[width, rank](indices)`, possibly as sub-tasks, for a suitable combination of width and indices so as to cover shape. Returns when all sub-tasks have completed. * [​`map`](/mojo/stdlib/algorithm/functional/map): Maps a function over the integer range \[0, size). This lets you apply an integer index-based operation across data captured by the mapped function (for example, an indexed buffer). * [​`parallelize`](/mojo/stdlib/algorithm/functional/parallelize): Executes func(0) ... func(num\_work\_items-1) as sub-tasks in parallel, and returns when all are complete. * [​`parallelize_over_rows`](/mojo/stdlib/algorithm/functional/parallelize_over_rows): Parallelize func over non-axis dims of shape. * [​`sync_parallelize`](/mojo/stdlib/algorithm/functional/sync_parallelize): Executes func(0) ... func(num\_work\_items-1) as parallel sub-tasks, and returns when all are complete. * [​`tile`](/mojo/stdlib/algorithm/functional/tile): A generator that launches work groups in the specified list of tile sizes. * [​`tile_and_unswitch`](/mojo/stdlib/algorithm/functional/tile_and_unswitch): Performs tile and unswitch functional transformation. 
* [​`tile_middle_unswitch_boundaries`](/mojo/stdlib/algorithm/functional/tile_middle_unswitch_boundaries): Divides 1d iteration space into three parts and tiles them with different steps. * [​`unswitch`](/mojo/stdlib/algorithm/functional/unswitch): Performs a functional unswitch transformation. * [​`vectorize`](/mojo/stdlib/algorithm/functional/vectorize): Simplifies SIMD optimized loops by mapping a function across a range from 0 to `size`, incrementing by `simd_width` at each step. The remainder of `size % simd_width` will run in separate iterations.
--- ## map
`map[origins: OriginSet, //, func: fn(Int) capturing -> None](size: Int)` Maps a function over the integer range \[0, size). This lets you apply an integer index-based operation across data captured by the mapped function (for example, an indexed buffer). For example: ```mojo from algorithm import map def main(): # Create list with initial values to act on var list = List[Float32](1.0, 2.0, 3.0, 4.0, 5.0) # Function applied to the value at each index @parameter fn exponent_2(idx: Int): list[idx] = 2.0 ** list[idx] # Apply the mapped function across the index range map[exponent_2](len(list)) # Show results for idx in range(len(list)): print(list[idx]) ``` Example output: ```output 2.0 4.0 8.0 16.0 32.0 ``` :::note Don't confuse `algorithm.map` (this eager, index-based helper) with [`iter.map`](/mojo/stdlib/iter/map/), which returns a lazy iterator that applies a function to each element. ::: **Parameters:** * ​origins ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): Capture origins for mapped function. * ​func (`fn(Int) capturing -> None`): Parameterized function applied at each index. **Args:** * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements in the index range.
--- ## parallelize
`parallelize[origins: OriginSet, //, func: fn(Int) capturing -> None](num_work_items: Int)` Executes func(0) ... func(num\_work\_items-1) as sub-tasks in parallel, and returns when all are complete. **Parameters:** * ​origins ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The capture origins. * ​func (`fn(Int) capturing -> None`): The function to invoke. **Args:** * ​num\_work\_items ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of parallel tasks. `parallelize[origins: OriginSet, //, func: fn(Int) capturing -> None](num_work_items: Int, num_workers: Int)` Executes func(0) ... func(num\_work\_items-1) as sub-tasks in parallel, and returns when all are complete. **Parameters:** * ​origins ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The capture origins. * ​func (`fn(Int) capturing -> None`): The function to invoke. **Args:** * ​num\_work\_items ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of parallel tasks. * ​num\_workers ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of workers to use for execution.
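The contract — run func(0) ... func(num\_work\_items-1) as parallel sub-tasks and return only when all complete — can be mimicked in Python with a thread pool standing in for the Mojo runtime's task system:

```python
from concurrent.futures import ThreadPoolExecutor

def parallelize(func, num_work_items: int, num_workers: int = 4) -> None:
    """Run func(0) ... func(num_work_items - 1) as parallel sub-tasks and
    return once all have finished."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # list() drains the iterator, forcing every task to complete.
        list(pool.map(func, range(num_work_items)))


# Square each slot of a shared buffer, one task per index.
results = [0] * 8
parallelize(lambda i: results.__setitem__(i, i * i), 8)
print(results)  # → [0, 1, 4, 9, 16, 25, 36, 49]
```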
--- ## parallelize_over_rows
`parallelize_over_rows[func: fn(Int, Int) capturing -> None](shape: IndexList[size, element_type=element_type], axis: Int, grain_size: Int)` Parallelize func over non-axis dims of shape. **Parameters:** * ​func (`fn(Int, Int) capturing -> None`): Function to call on range of rows. **Args:** * ​shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): Shape to parallelize over. * ​axis ([`Int`](/mojo/stdlib/builtin/int/Int)): Rows are slices along the axis dimension of shape. * ​grain\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The minimum number of elements to warrant using an additional thread.
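A Python sketch of the row-chunking logic, run serially: rows are the slices along `axis`, and `grain_size` sets the minimum number of elements per chunk. The exact chunk-size heuristic here is an assumption; the real implementation also balances chunks against the worker count:

```python
from math import prod

def parallelize_over_rows(func, shape: tuple, axis: int, grain_size: int) -> None:
    """Split the rows (slices along `axis`) into chunks of at least
    `grain_size` elements and call func(start_row, end_row) per chunk."""
    num_rows = prod(d for i, d in enumerate(shape) if i != axis)
    row_len = shape[axis]
    # Rows per chunk so each chunk holds >= grain_size elements (ceil division).
    rows_per_chunk = max(1, -(-grain_size // row_len))
    for start in range(0, num_rows, rows_per_chunk):
        func(start, min(start + rows_per_chunk, num_rows))


# 6 rows of length 4; grain_size 8 -> chunks of 2 rows each.
chunks = []
parallelize_over_rows(lambda s, e: chunks.append((s, e)), (6, 4), 1, 8)
print(chunks)  # → [(0, 2), (2, 4), (4, 6)]
```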
--- ## sync_parallelize
`sync_parallelize[origins: OriginSet, //, func: fn(Int) raises capturing -> None](num_work_items: Int)` Executes func(0) ... func(num\_work\_items-1) as parallel sub-tasks, and returns when all are complete. TODO: Currently exceptions raised by func will cause a trap rather than be propagated back to the caller. **Parameters:** * ​origins ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The capture origins. * ​func (`fn(Int) raises capturing -> None`): The function to invoke. **Args:** * ​num\_work\_items ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of parallel tasks.
--- ## tile (Functional)
`tile[workgroup_function: Static1DTileUnitFunc, tile_size_list: VariadicList[Int]](offset: Int, upperbound: Int)` A generator that launches work groups in the specified list of tile sizes. A workgroup function is a function that can process a configurable consecutive "tile" of workload. E.g. `work_on[3](5)` should launch computation on items 5, 6, 7, and should be semantically equivalent to `work_on[1](5)`, `work_on[1](6)`, `work_on[1](7)`. This generator will try to proceed with the given list of tile sizes in the listed order. E.g. `tile[func, (3,2,1)](offset, upperbound)` will try to call `func[3]` starting from offset until the remaining work before upperbound is less than 3, then try `func[2]`, and then `func[1]`, etc. **Parameters:** * ​workgroup\_function ([`Static1DTileUnitFunc`](/mojo/stdlib/algorithm/functional/#static1dtileunitfunc)): Workgroup function that processes one tile of workload. * ​tile\_size\_list ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of tile sizes to launch work. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial index to start the work from. * ​upperbound ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime upperbound that the work function should not exceed. `tile[workgroup_function: Dynamic1DTileUnitFunc](offset: Int, upperbound: Int, tile_size_list: VariadicList[Int])` A generator that launches work groups in the specified list of tile sizes. This is the version of the tile generator for the case where the workgroup function can take the tile size as a runtime value. **Parameters:** * ​workgroup\_function ([`Dynamic1DTileUnitFunc`](/mojo/stdlib/algorithm/functional/#dynamic1dtileunitfunc)): Workgroup function that processes one tile of workload. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial index to start the work from. * ​upperbound ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime upperbound that the work function should not exceed. 
* ​tile\_size\_list ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of tile sizes to launch work. `tile[secondary_tile_size_list: VariadicList[Int], secondary_cleanup_tile: Int, workgroup_function: BinaryTile1DTileUnitFunc](offset: Int, upperbound: Int, primary_tile_size_list: VariadicList[Int], primary_cleanup_tile: Int)` A generator that launches work groups in specified list of tile sizes until the sum of primary\_tile\_sizes has exceeded the upperbound. **Parameters:** * ​secondary\_tile\_size\_list ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of static tile sizes to launch work. * ​secondary\_cleanup\_tile ([`Int`](/mojo/stdlib/builtin/int/Int)): Last static tile to use when primary tile sizes don't fit exactly within the upperbound. * ​workgroup\_function ([`BinaryTile1DTileUnitFunc`](/mojo/stdlib/algorithm/functional/#binarytile1dtileunitfunc)): Workgroup function that processes one tile of workload. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial index to start the work from. * ​upperbound ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime upperbound that the work function should not exceed. * ​primary\_tile\_size\_list ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of dynamic tile sizes to launch work. * ​primary\_cleanup\_tile ([`Int`](/mojo/stdlib/builtin/int/Int)): Last dynamic tile to use when primary tile sizes don't fit exactly within the upperbound. `tile[workgroup_function: Static2DTileUnitFunc, tile_sizes_x: VariadicList[Int], tile_sizes_y: VariadicList[Int]](offset_x: Int, offset_y: Int, upperbound_x: Int, upperbound_y: Int)` Launches workgroup\_function using the largest tile sizes possible in each dimension, starting from the x and y offset, until the x and y upperbounds are reached. 
**Parameters:** * ​workgroup\_function ([`Static2DTileUnitFunc`](/mojo/stdlib/algorithm/functional/#static2dtileunitfunc)): Function that is invoked for each tile and offset. * ​tile\_sizes\_x ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of tile sizes to use for the first parameter of workgroup\_function. * ​tile\_sizes\_y ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of tile sizes to use for the second parameter of workgroup\_function. **Args:** * ​offset\_x ([`Int`](/mojo/stdlib/builtin/int/Int)): Initial x offset passed to workgroup\_function. * ​offset\_y ([`Int`](/mojo/stdlib/builtin/int/Int)): Initial y offset passed to workgroup\_function. * ​upperbound\_x ([`Int`](/mojo/stdlib/builtin/int/Int)): Max offset in x dimension passed to workgroup function. * ​upperbound\_y ([`Int`](/mojo/stdlib/builtin/int/Int)): Max offset in y dimension passed to workgroup function.
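The greedy tile-size selection described above can be sketched in Python (1-D static-size variant only; in Mojo the tile size is a compile-time parameter of the workgroup function rather than a runtime argument):

```python
def tile(workgroup_function, tile_size_list, offset: int, upperbound: int) -> None:
    """Greedy tiling: for each size in listed order, launch full tiles while
    they fit, then fall through to the next (smaller) size for the residue."""
    for size in tile_size_list:
        while offset + size <= upperbound:
            workgroup_function(size, offset)
            offset += size


# tile[func, (3,2,1)](0, 8): two size-3 tiles, then one size-2 residue tile.
calls = []
tile(lambda size, off: calls.append((size, off)), (3, 2, 1), 0, 8)
print(calls)  # → [(3, 0), (3, 3), (2, 6)]
```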
--- ## tile_and_unswitch
`tile_and_unswitch[workgroup_function: Static1DTileUnswitchUnitFunc, tile_size_list: VariadicList[Int]](offset: Int, upperbound: Int)` Performs tile and unswitch functional transformation. A variant of static tile given a workgroup function that can be unswitched. This generator is a fused version of tile and unswitch, where the static unswitch is true throughout the "inner" portion of the workload and is false only on the residue tile. **Parameters:** * ​workgroup\_function ([`Static1DTileUnswitchUnitFunc`](/mojo/stdlib/algorithm/functional/#static1dtileunswitchunitfunc)): Workgroup function that processes one tile of workload. * ​tile\_size\_list ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of tile sizes to launch work. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial index to start the work from. * ​upperbound ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime upperbound that the work function should not exceed. `tile_and_unswitch[workgroup_function: Dynamic1DTileUnswitchUnitFunc](offset: Int, upperbound: Int, tile_size_list: VariadicList[Int])` Performs tile and unswitch functional transformation. A variant of dynamic tile given a workgroup function that can be unswitched. This generator is a fused version of tile and unswitch, where the static unswitch is true throughout the "inner" portion of the workload and is false only on the residue tile. **Parameters:** * ​workgroup\_function ([`Dynamic1DTileUnswitchUnitFunc`](/mojo/stdlib/algorithm/functional/#dynamic1dtileunswitchunitfunc)): Workgroup function that processes one tile of workload. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial index to start the work from. * ​upperbound ([`Int`](/mojo/stdlib/builtin/int/Int)): The runtime upperbound that the work function should not exceed. * ​tile\_size\_list ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of tile sizes to launch work.
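The fused transformation can be sketched in Python for a single tile size — interior tiles run with the switch set (no per-element bounds checks needed), the residue tile with it cleared. This is a simplification: the real generator accepts a list of tile sizes and specializes the switch at compile time:

```python
def tile_and_unswitch(workgroup_function, tile_size: int,
                      offset: int, upperbound: int) -> None:
    """Fused tile + unswitch for one static tile size: interior tiles get
    switch=True, the residue tile gets switch=False."""
    while offset + tile_size <= upperbound:
        workgroup_function(True, offset, offset + tile_size)   # interior
        offset += tile_size
    if offset < upperbound:
        workgroup_function(False, offset, upperbound)          # residue


calls = []
tile_and_unswitch(lambda sw, lo, hi: calls.append((sw, lo, hi)), 4, 0, 10)
print(calls)  # → [(True, 0, 4), (True, 4, 8), (False, 8, 10)]
```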
--- ## tile_middle_unswitch_boundaries
`tile_middle_unswitch_boundaries[work_fn: Static1DTileUnitFuncWithFlag, middle_tile_sizes: VariadicList[Int], left_tile_size: Int = 1, right_tile_size: Int = 1](left_boundary_start: Int, left_boundary_end: Int, right_boundary_start: Int, right_boundary_end: Int)` Divides 1d iteration space into three parts and tiles them with different steps. The 1d iteration space is divided into: 1\. \[left\_boundary\_start, left\_boundary\_end), affected by the left boundary. 2\. \[left\_boundary\_end, right\_boundary\_start), not affected by any boundary. 3\. \[right\_boundary\_start, right\_boundary\_end), affected by the right boundary. work\_fn's switch is true for the left and right boundaries, implying boundary conditions like padding in convolution. The middle part is tiled with static tile sizes with the switch as false. `middle_tile_sizes` should be in descending order for optimal performance. (A larger tile size appearing later in the list would fail the while-loop.) **Parameters:** * ​work\_fn ([`Static1DTileUnitFuncWithFlag`](/mojo/stdlib/algorithm/functional/#static1dtileunitfuncwithflag)): Work function that processes one tile of workload. * ​middle\_tile\_sizes ([`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList)): List of tile sizes for the middle part. * ​left\_tile\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Tile size for the left boundary region. * ​right\_tile\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): Tile size for the right boundary region. **Args:** * ​left\_boundary\_start ([`Int`](/mojo/stdlib/builtin/int/Int)): Start index of the left boundary. * ​left\_boundary\_end ([`Int`](/mojo/stdlib/builtin/int/Int)): End index of the left boundary. * ​right\_boundary\_start ([`Int`](/mojo/stdlib/builtin/int/Int)): Start index of the right boundary. * ​right\_boundary\_end ([`Int`](/mojo/stdlib/builtin/int/Int)): End index of the right boundary. 
`tile_middle_unswitch_boundaries[work_fn: Static1DTileUnitFuncWithFlags, tile_size: Int, size: Int]()` Tiles the 1d iteration space with boundary conditions at both ends. This generator is primarily for convolution with static shapes. `work_fn`'s flags hint the function to handle padding at the boundary. The size is the static output row size, i.e., the WO dimension. **Parameters:** * ​work\_fn ([`Static1DTileUnitFuncWithFlags`](/mojo/stdlib/algorithm/functional/#static1dtileunitfuncwithflags)): Work function that updates one tile. It has two flags for the left and right boundaries, respectively. * ​tile\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): 1D Tile size. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): Iteration range is \[0, size).
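The three-region scheme can be sketched in Python. Unit steps at the boundaries stand in for `left_tile_size`/`right_tile_size` (an assumption for brevity), and the middle is tiled flag-off with the descending tile sizes:

```python
def tile_middle_unswitch_boundaries(work_fn, middle_tile_sizes,
                                    lb_start, lb_end, rb_start, rb_end):
    """Three-region tiling: boundary regions run with the flag set
    (boundary handling on), the middle is tiled flag-off."""
    for i in range(lb_start, lb_end):
        work_fn(1, True, i)                   # left boundary, tile size 1
    offset = lb_end
    for size in middle_tile_sizes:
        while offset + size <= rb_start:
            work_fn(size, False, offset)      # interior, no boundary checks
            offset += size
    for i in range(rb_start, rb_end):
        work_fn(1, True, i)                   # right boundary


calls = []
tile_middle_unswitch_boundaries(
    lambda size, flag, off: calls.append((size, flag, off)), (4, 1), 0, 2, 10, 12)
print(calls)
```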
--- ## unswitch
`unswitch[switched_func: SwitchedFunction](dynamic_switch: Bool)` Performs a functional unswitch transformation. Unswitch is a simple pattern similar in idea to the loop unswitching pass, but extended to functional patterns. The pattern facilitates the following code transformation that reduces the number of branches in the generated code. Before: ``` for i in range(...) if i < xxx: ... ``` After: ``` if i < xxx for i in range(...) ... else for i in range(...) if i < xxx: ... ``` This unswitch function generalizes that pattern with the help of meta parameters and can be used to perform both loop unswitching and other tile predicate lifting like in simd and amx. TODO: Generalize to support multiple predicates. TODO: Once nested lambdas compose well should make unswitch compose with tile in an easy way. **Parameters:** * ​switched\_func ([`SwitchedFunction`](/mojo/stdlib/algorithm/functional/#switchedfunction)): The function containing the inner loop logic that can be unswitched. **Args:** * ​dynamic\_switch ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The dynamic condition that enables the unswitched code path. **Raises:** If the operation fails. `unswitch[switched_func: fn[sw: Bool]() capturing -> None](dynamic_switch: Bool)` Performs a functional unswitch transformation. Unswitch is a simple pattern similar in idea to the loop unswitching pass, but extended to functional patterns. The pattern facilitates the following code transformation that reduces the number of branches in the generated code. Before: ``` for i in range(...) if i < xxx: ... ``` After: ``` if i < xxx for i in range(...) ... else for i in range(...) if i < xxx: ... ``` This unswitch function generalizes that pattern with the help of meta parameters and can be used to perform both loop unswitching and other tile predicate lifting like in simd and amx. TODO: Generalize to support multiple predicates. TODO: Once nested lambdas compose well should make unswitch compose with tile in an easy way. 
**Parameters:** * ​switched\_func (`fn[sw: Bool]() capturing -> None`): The function containing the inner loop logic that can be unswitched. **Args:** * ​dynamic\_switch ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The dynamic condition that enables the unswitched code path. `unswitch[switched_func: SwitchedFunction2](dynamic_switch_a: Bool, dynamic_switch_b: Bool)` Performs a functional 2-predicates unswitch transformation. **Parameters:** * ​switched\_func ([`SwitchedFunction2`](/mojo/stdlib/algorithm/functional/#switchedfunction2)): The function containing the inner loop logic that has 2 predicates which can be unswitched. **Args:** * ​dynamic\_switch\_a ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The first dynamic condition that enables the outer unswitched code path. * ​dynamic\_switch\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The second dynamic condition that enables the inner unswitched code path.
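The transformation itself can be mimicked in Python. The key difference is that Mojo specializes `switched_func` at compile time on the `sw` parameter, so each branch gets its own branch-free loop body; this sketch merely dispatches at runtime:

```python
def unswitch(switched_func, dynamic_switch: bool) -> None:
    """Lift a dynamic predicate out of an inner loop: each branch gets a
    specialized body, so the hot loop itself carries no per-item branch."""
    if dynamic_switch:
        switched_func(True)
    else:
        switched_func(False)


# Predicate "all indices in bounds" hoisted out of the loop body.
data, bound = list(range(6)), 4
out = []

def body(in_bounds: bool):
    if in_bounds:
        for i in range(len(data)):        # fast path: no per-item check
            out.append(data[i] * 2)
    else:
        for i in range(len(data)):        # slow path: per-item check
            if i < bound:
                out.append(data[i] * 2)

unswitch(body, bound >= len(data))
print(out)  # → [0, 2, 4, 6]
```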
--- ## vectorize
`vectorize[func: fn, //, simd_width: Int, /, *, unroll_factor: Int = 1](size: Int, closure: func)` Simplifies SIMD optimized loops by mapping a function across a range from 0 to `size`, incrementing by `simd_width` at each step. The remainder of `size % simd_width` will run in separate iterations. The below example demonstrates how you could improve the performance of a loop by setting multiple values at the same time using SIMD registers on the machine: ```mojo from algorithm.functional import vectorize from sys import simd_width_of # The amount of elements to loop through comptime size = 10 # How many DType.int32 elements fit into the SIMD register (4 on 128bit) comptime simd_width = simd_width_of[DType.int32]() # assumed to be 4 in this example fn main(): var p = UnsafePointer[Int32].alloc(size) fn closure[width: Int](i: Int) unified {mut}: print("storing", width, "els at pos", i) p.store[width=width](i, i) vectorize[simd_width](size, closure) print(p.load[width=simd_width]()) print(p.load[width=simd_width](simd_width)) ``` On a machine with a SIMD register size of 128 bits, this will set 4xInt32 values on each iteration. The remainder of 10 % 4 is 2, so those last two elements will be set in two separate iterations: ```plaintext storing 4 els at pos 0 storing 4 els at pos 4 storing 1 els at pos 8 storing 1 els at pos 9 [0, 0, 0, 0, 4, 4, 4, 4, 8, 9] ``` You can also unroll the loop to potentially improve performance at the cost of binary size: ``` vectorize[simd_width, unroll_factor=2](size, closure) ``` In the generated assembly the function calls will be repeated, resulting in fewer arithmetic, comparison, and conditional jump operations. The assembly would look like this in pseudocode: ``` closure[4](0) closure[4](4) # Remainder loop won't unroll unless `size` is passed as a parameter for i in range(8, 10): closure[1](i) ``` You can pass `size` as a parameter if it's compile-time known to reduce the iterations for the remainder. 
This only occurs if the remainder is a power of 2 (2, 4, 8, 16, ...). If the remainder is not a power of 2, the remainder loop will still be unrolled for performance improvements. **Parameters:** * ​func (`fn`): The function that will be called in the loop body. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width. * ​unroll\_factor ([`Int`](/mojo/stdlib/builtin/int/Int)): The unroll factor for the main loop (Default 1). **Args:** * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The upper limit for the loop. * ​closure (`func`): The captured state of the function bound to func. `vectorize[func: fn, //, simd_width: Int, /, *, size: Int, unroll_factor: Int = size if is_gpu() else 1](closure: func)` Simplifies SIMD optimized loops by mapping a function across a range from 0 to `size`, incrementing by `simd_width` at each step. The remainder of `size % simd_width` will run in a single iteration if it's a power of 2. The example below demonstrates how you can improve the performance of a loop by setting multiple values at the same time using SIMD registers on the machine:

```mojo
from algorithm.functional import vectorize
from sys import simd_width_of

# The number of elements to loop through
comptime size = 10
# How many DType.int32 elements fit into the SIMD register (4 on 128-bit)
comptime simd_width = simd_width_of[DType.int32]()  # assumed to be 4 in this example

fn main():
    var p = UnsafePointer[Int32].alloc(size)

    # The closure can capture the `p` pointer with unified {mut}
    fn closure[width: Int](i: Int) unified {mut}:
        print("storing", width, "els at pos", i)
        p.store[width=width](i, i)

    vectorize[simd_width, size=size](closure)
    print(p.load[width=simd_width]())
    print(p.load[width=simd_width](simd_width))
```

On a machine with a SIMD register size of 128, this will set 4xInt32 values on each iteration.
The remainder of 10 % 4 is 2, so those last two elements will be set in a single iteration:

```plaintext
storing 4 els at pos 0
storing 4 els at pos 4
storing 2 els at pos 8
[0, 0, 0, 0, 4, 4, 4, 4, 8, 8]
```

If the remainder is not a power of 2 (2, 4, 8, 16, ...) there will be a separate iteration for each element. However, passing `size` as a parameter also allows the loop for the remaining elements to be unrolled. You can also unroll the main loop to potentially improve performance at the cost of binary size:

```
vectorize[simd_width, size=size, unroll_factor=2](closure)
```

In the generated assembly the function calls will be repeated, resulting in fewer arithmetic, comparison, and conditional jump operations. The assembly would look like this in pseudocode:

```
closure[4](0)
closure[4](4)
closure[2](8)
```

**Parameters:** * ​func (`fn`): The function that will be called in the loop body. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The upper limit for the loop. * ​unroll\_factor ([`Int`](/mojo/stdlib/builtin/int/Int)): The unroll factor for the main loop (defaults to `size` on GPU targets, otherwise 1). **Args:** * ​closure (`func`): The captured state of the function bound to func.
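To make the iteration pattern concrete, here is a Python sketch (an illustration, not the Mojo implementation) of the `(width, index)` schedule `vectorize` follows when `size` is known at compile time, including the single narrower step taken when the remainder is a power of 2:

```python
def vectorize_schedule(size, simd_width):
    """Return the (width, index) pairs a vectorized loop would visit."""
    steps = []
    i = 0
    # Main loop: full-width steps.
    while i + simd_width <= size:
        steps.append((simd_width, i))
        i += simd_width
    remainder = size - i
    # When `size` is a compile-time parameter and the remainder is a
    # power of 2, the remainder is handled in one narrower step;
    # otherwise each remaining element gets its own width-1 iteration.
    if remainder and remainder & (remainder - 1) == 0:
        steps.append((remainder, i))
    else:
        for j in range(i, size):
            steps.append((1, j))
    return steps


# Matches the documented output for size=10, simd_width=4:
assert vectorize_schedule(10, 4) == [(4, 0), (4, 4), (2, 8)]
# A remainder of 3 is not a power of 2, so it runs element by element:
assert vectorize_schedule(11, 4) == [(4, 0), (4, 4), (1, 8), (1, 9), (1, 10)]
```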
--- ## algorithm
Implements the algorithm package. ## Modules * [​`functional`](/mojo/stdlib/algorithm/functional/): Implements higher-order functions. * [​`memory`](/mojo/stdlib/algorithm/memory/): Implements `parallel_memcpy`. * [​`reduction`](/mojo/stdlib/algorithm/reduction/): Implements SIMD reductions.
--- ## memory
Implements `parallel_memcpy`. You can import these APIs from the `algorithm` package. For example: ```mojo from algorithm import parallel_memcpy ``` ## Functions * [​`parallel_memcpy`](/mojo/stdlib/algorithm/memory/parallel_memcpy): Copies `count` elements from a memory buffer `src` to `dest` in parallel by spawning `num_tasks` tasks each copying `count_per_task` elements.
--- ## parallel_memcpy
`parallel_memcpy[dtype: DType](*, dest: UnsafePointer[Scalar[dtype], origin], src: UnsafePointer[Scalar[dtype], origin], count: Int, count_per_task: Int, num_tasks: Int)` Copies `count` elements from a memory buffer `src` to `dest` in parallel by spawning `num_tasks` tasks each copying `count_per_task` elements. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The element dtype. **Args:** * ​dest ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): The destination buffer. * ​src ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): The source buffer. * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements in the buffer. * ​count\_per\_task ([`Int`](/mojo/stdlib/builtin/int/Int)): Task size. * ​num\_tasks ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of tasks to run in parallel. `parallel_memcpy[dtype: DType](*, dest: UnsafePointer[Scalar[dtype], origin], src: UnsafePointer[Scalar[dtype], origin], count: Int)` Copies `count` elements from a memory buffer `src` to `dest` in parallel. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The element dtype. **Args:** * ​dest ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): The destination pointer. * ​src ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): The source pointer. * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to copy.
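As an illustration of how the copy is partitioned across tasks (a sketch only, not the library's actual task scheduling), the following Python function computes the `(start, length)` range each of the `num_tasks` tasks would copy:

```python
def parallel_memcpy_chunks(count, count_per_task, num_tasks):
    """Return the (start, length) range each task would copy."""
    ranges = []
    for task in range(num_tasks):
        start = task * count_per_task
        # The last task may get a short chunk; tasks past the end get none.
        length = min(count_per_task, max(0, count - start))
        if length:
            ranges.append((start, length))
    return ranges


# 10 elements split into tasks of 4: two full chunks and one of 2.
assert parallel_memcpy_chunks(10, 4, 3) == [(0, 4), (4, 4), (8, 2)]
```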
--- ## cumsum (Reduction)
`cumsum[dtype: DType](dst: Span[Scalar[dtype], origin], src: Span[Scalar[dtype], origin])` Computes the cumulative sum of all elements in a buffer. dst\[i] = src\[i] + src\[i-1] + ... + src\[0]. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​dst ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer that stores the result of cumulative sum operation. * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer of elements for which the cumulative sum is computed.
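The documented recurrence `dst[i] = src[i] + src[i-1] + ... + src[0]` can be sketched in Python as a running total (an illustration of the semantics, not the SIMD implementation):

```python
def cumsum(src):
    """dst[i] = src[i] + src[i-1] + ... + src[0]."""
    dst = []
    total = 0
    for x in src:
        total += x
        dst.append(total)
    return dst


assert cumsum([1, 2, 3, 4]) == [1, 3, 6, 10]
```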
--- ## reduction
Implements SIMD reductions. You can import these APIs from the `algorithm` package. For example: ```mojo from algorithm import map_reduce ``` ## Functions * [​`cumsum`](/mojo/stdlib/algorithm/reduction/cumsum): Computes the cumulative sum of all elements in a buffer. dst\[i] = src\[i] + src\[i-1] + ... + src\[0]. * [​`map_reduce`](/mojo/stdlib/algorithm/reduction/map_reduce): Stores the result of calling input\_gen\_fn in dst and simultaneously reduces the result using a custom reduction function. * [​`max`](/mojo/stdlib/algorithm/reduction/max): Computes the max element in a buffer. * [​`mean`](/mojo/stdlib/algorithm/reduction/mean): Computes the mean value of the elements in a buffer. * [​`min`](/mojo/stdlib/algorithm/reduction/min): Computes the min element in a buffer. * [​`product`](/mojo/stdlib/algorithm/reduction/product): Computes the product of the buffer elements. * [​`reduce`](/mojo/stdlib/algorithm/reduction/reduce): Computes a custom reduction of buffer elements. * [​`reduce_boolean`](/mojo/stdlib/algorithm/reduction/reduce_boolean): Computes a bool reduction of buffer elements. The reduction will early exit if the `continue_fn` returns False. * [​`sum`](/mojo/stdlib/algorithm/reduction/sum): Computes the sum of buffer elements. * [​`variance`](/mojo/stdlib/algorithm/reduction/variance): Given a mean, computes the variance of elements in a buffer.
--- ## map_reduce
`map_reduce[simd_width: Int, dtype: DType, acc_type: DType, origins_gen: OriginSet, input_gen_fn: fn[dtype: DType, width: Int](Int) capturing -> SIMD[dtype, width], origins_vec: OriginSet, reduce_vec_to_vec_fn: fn[acc_type: DType, dtype: DType, width: Int](SIMD[acc_type, width], SIMD[dtype, width]) capturing -> SIMD[acc_type, width], reduce_vec_to_scalar_fn: fn[dtype: DType, width: Int](SIMD[dtype, width]) -> Scalar[dtype]](dst: Span[Scalar[dtype], origin], init: Scalar[acc_type]) -> Scalar[acc_type]` Stores the result of calling input\_gen\_fn in dst and simultaneously reduces the result using a custom reduction function. **Parameters:** * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The vector width for the computation. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The buffer elements dtype. * ​acc\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the reduction accumulator. * ​origins\_gen ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The OriginSet of captured arguments by the input\_gen\_fn. * ​input\_gen\_fn (`fn[dtype: DType, width: Int](Int) capturing -> SIMD[dtype, width]`): A function that generates inputs to reduce. * ​origins\_vec ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The OriginSet of captured arguments by the reduce\_vec\_to\_vec\_fn. * ​reduce\_vec\_to\_vec\_fn (`fn[acc_type: DType, dtype: DType, width: Int](SIMD[acc_type, width], SIMD[dtype, width]) capturing -> SIMD[acc_type, width]`): A mapping function. This function is used to combine (accumulate) two chunks of input data: e.g. we load two `8xfloat32` vectors of elements and need to reduce them into a single `8xfloat32` vector. * ​reduce\_vec\_to\_scalar\_fn (`fn[dtype: DType, width: Int](SIMD[dtype, width]) -> Scalar[dtype]`): A reduction function. This function is used to reduce a vector to a scalar, e.g. when we have an `8xfloat32` vector and want to reduce it to a `float32` scalar.
**Args:** * ​dst ([`Span`](/mojo/stdlib/memory/span/Span)): The output buffer. * ​init ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The initial value to use in accumulator. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The computed reduction value. `map_reduce[simd_width: Int, dtype: DType, acc_type: DType, origins_gen: OriginSet, input_gen_fn: fn[dtype: DType, width: Int](Int) capturing -> SIMD[dtype, width], origins_vec: OriginSet, reduce_vec_to_vec_fn: fn[acc_type: DType, dtype: DType, width: Int](SIMD[acc_type, width], SIMD[dtype, width]) capturing -> SIMD[acc_type, width], reduce_vec_to_scalar_fn: fn[dtype: DType, width: Int](SIMD[dtype, width]) -> Scalar[dtype], output_fn: fn[dtype_: DType, width: Int, alignment: Int](idx: Int, val: SIMD[dtype_, width]) capturing -> None](length: Int, init: Scalar[acc_type]) -> Scalar[acc_type]` Performs a vectorized map-reduce operation over a sequence. **Parameters:** * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD vector width to use. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the input elements. * ​acc\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the accumulator. * ​origins\_gen ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): Origin set for the input generation function. * ​input\_gen\_fn (`fn[dtype: DType, width: Int](Int) capturing -> SIMD[dtype, width]`): Function that generates input values at each index. * ​origins\_vec ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): Origin set for the reduction function. * ​reduce\_vec\_to\_vec\_fn (`fn[acc_type: DType, dtype: DType, width: Int](SIMD[acc_type, width], SIMD[dtype, width]) capturing -> SIMD[acc_type, width]`): Function that reduces a vector into the accumulator. * ​reduce\_vec\_to\_scalar\_fn (`fn[dtype: DType, width: Int](SIMD[dtype, width]) -> Scalar[dtype]`): Function that reduces a final vector to a scalar. 
* ​output\_fn (`fn[dtype_: DType, width: Int, alignment: Int](idx: Int, val: SIMD[dtype_, width]) capturing -> None`): Function to output intermediate results. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to process. * ​init ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The initial accumulator value. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The final reduced scalar value.
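A scalar Python sketch of the map-reduce semantics (an illustration that ignores SIMD widths and origin sets; the generator and reduction lambdas below are hypothetical stand-ins for `input_gen_fn` and `reduce_vec_to_vec_fn`):

```python
def map_reduce(length, input_gen_fn, reduce_fn, init):
    """Store generated values while simultaneously reducing them."""
    dst = []
    acc = init
    for i in range(length):
        val = input_gen_fn(i)
        dst.append(val)            # stored output, like the `dst` buffer
        acc = reduce_fn(acc, val)  # simultaneous reduction
    return dst, acc


# Generate squares and sum them in a single pass.
dst, total = map_reduce(5, lambda i: i * i, lambda a, b: a + b, 0)
assert dst == [0, 1, 4, 9, 16]
assert total == 30
```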
--- ## max (Reduction)
`max[dtype: DType](src: Span[Scalar[dtype], origin]) -> Scalar[dtype]` Computes the max element in a buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The maximum of the buffer elements. **Raises:** If the operation fails. `max[dtype: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, /, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](input_shape: IndexList[size], reduce_dim: Int, context: DeviceContextPtr = DeviceContextPtr())` Computes the max across the input and output shape. This performs the max computation on the domain specified by `input_shape`, loading the inputs using the `input_fn`. The results are stored using the `output_fn`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output. * ​input\_fn (`fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to load the input. * ​output\_fn (`fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None`): The function to store the output. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​target (`StringSlice`): The target to run on. **Args:** * ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The input shape. * ​reduce\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to perform the max on. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The pointer to DeviceContext. **Raises:** If the operation fails.
--- ## mean (Reduction)
`mean[dtype: DType](src: Span[Scalar[dtype], origin]) -> Scalar[dtype]` Computes the mean value of the elements in a buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer of elements for which the mean is computed. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The mean value of the elements in the given buffer. **Raises:** If the operation fails. `mean[dtype: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, /, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](input_shape: IndexList[size], reduce_dim: Int, output_shape: IndexList[size], context: DeviceContextPtr = DeviceContextPtr())` Computes the mean across the input and output shape. This performs the mean computation on the domain specified by `input_shape`, loading the inputs using the `input_fn`. The results' domain is `output_shape` which are stored using the `output_fn`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output. * ​input\_fn (`fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to load the input. * ​output\_fn (`fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None`): The function to store the output. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​target (`StringSlice`): The target to run on. **Args:** * ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The input shape. * ​reduce\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to perform the mean on. 
* ​output\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The output shape. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The pointer to DeviceContext. **Raises:** If the operation fails. `mean[dtype: DType, input_fn_1d: fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]](length: Int) -> Scalar[dtype]` Computes the arithmetic mean of values generated by a function. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements. * ​input\_fn\_1d (`fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]`): A function that generates SIMD values at each index. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to average. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The mean value. For integral types, uses integer division. **Raises:** Declared as raising only to comply with how generator functions are used in this module.
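A Python sketch of the 1D mean semantics, including the documented integer-division behavior for integral types (an illustration, not the Mojo implementation):

```python
def mean_1d(length, input_fn_1d):
    """Mean of values produced by a generator function."""
    total = sum(input_fn_1d(i) for i in range(length))
    # Per the docs, integral element types use integer division.
    if isinstance(total, int):
        return total // length
    return total / length


assert mean_1d(4, lambda i: i) == 1        # (0+1+2+3) // 4
assert mean_1d(2, lambda i: float(i)) == 0.5
```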
--- ## min
`min[dtype: DType](src: Span[Scalar[dtype], origin]) -> Scalar[dtype]` Computes the min element in a buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The minimum of the buffer elements. **Raises:** If the operation fails. `min[dtype: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, /, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](input_shape: IndexList[size], reduce_dim: Int, context: DeviceContextPtr = DeviceContextPtr())` Computes the min across the input and output shape. This performs the min computation on the domain specified by `input_shape`, loading the inputs using the `input_fn`. The results are stored using the `output_fn`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output. * ​input\_fn (`fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to load the input. * ​output\_fn (`fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None`): The function to store the output. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​target (`StringSlice`): The target to run on. **Args:** * ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The input shape. * ​reduce\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to perform the min on. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The pointer to DeviceContext. **Raises:** If the operation fails.
--- ## product (Reduction)
`product[dtype: DType](src: Span[Scalar[dtype], origin]) -> Scalar[dtype]` Computes the product of the buffer elements. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The product of the buffer elements. **Raises:** If the operation fails. `product[dtype: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, /, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](input_shape: IndexList[size], reduce_dim: Int, context: DeviceContextPtr = DeviceContextPtr())` Computes the product across the input and output shape. This performs the product computation on the domain specified by `input_shape`, loading the inputs using the `input_fn`. The results are stored using the `output_fn`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output. * ​input\_fn (`fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to load the input. * ​output\_fn (`fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None`): The function to store the output. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​target (`StringSlice`): The target to run on. **Args:** * ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The input shape. * ​reduce\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to perform the product on. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The pointer to DeviceContext. **Raises:** If the operation fails.
--- ## reduce (Reduction)
`reduce[reduce_fn: fn[acc_type: DType, dtype: DType, width: Int](SIMD[acc_type, width], SIMD[dtype, width]) capturing -> SIMD[acc_type, width], dtype: DType](src: Span[Scalar[dtype], origin], init: Scalar[dtype]) -> Scalar[dtype]` Computes a custom reduction of buffer elements. **Parameters:** * ​reduce\_fn (`fn[acc_type: DType, dtype: DType, width: Int](SIMD[acc_type, width], SIMD[dtype, width]) capturing -> SIMD[acc_type, width]`): The lambda implementing the reduction. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The input buffer. * ​init ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The initial value to use in accumulator. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The computed reduction value. **Raises:** If the operation fails.
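A scalar Python sketch of a custom reduction with an initial accumulator value (an illustration of the semantics, not the vectorized implementation):

```python
from functools import reduce as _py_reduce


def custom_reduce(src, reduce_fn, init):
    """Fold `src` into a single value, starting from `init`."""
    return _py_reduce(reduce_fn, src, init)


# A product reduction with initial accumulator 1.
assert custom_reduce([2, 3, 4], lambda acc, x: acc * x, 1) == 24
```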
--- ## reduce_boolean
`reduce_boolean[reduce_fn: fn[dtype: DType, width: Int](SIMD[dtype, width]) capturing -> Bool, continue_fn: fn(Bool) capturing -> Bool, dtype: DType](src: Span[Scalar[dtype], origin], init: Bool) -> Bool` Computes a bool reduction of buffer elements. The reduction will early exit if the `continue_fn` returns False. **Parameters:** * ​reduce\_fn (`fn[dtype: DType, width: Int](SIMD[dtype, width]) capturing -> Bool`): A boolean reduction function. This function is used to reduce a vector to a scalar, e.g. when we have an `8xfloat32` vector and want to reduce it to a `bool`. * ​continue\_fn (`fn(Bool) capturing -> Bool`): A function to indicate whether we want to continue processing the rest of the iterations. This takes the result of the reduce\_fn and returns True to continue processing and False to early exit. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The input buffer. * ​init ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The initial value to use. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): The computed reduction value.
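A scalar Python sketch of the early-exit behavior (an illustration; the real function reduces SIMD chunks rather than single elements):

```python
def reduce_boolean(src, reduce_fn, continue_fn, init):
    """Boolean reduction that stops as soon as continue_fn returns False."""
    result = init
    for chunk in src:
        result = reduce_fn(chunk)
        if not continue_fn(result):
            break  # early exit
    return result


# "Is any element negative?" -- scanning stops once one is found.
data = [3, 7, -2, 5]
found = reduce_boolean(data, lambda x: x < 0, lambda r: not r, False)
assert found is True
```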
--- ## sum (Reduction)
`sum[dtype: DType](src: Span[Scalar[dtype], origin]) -> Scalar[dtype]` Computes the sum of buffer elements. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The sum of the buffer elements. **Raises:** If the operation fails. `sum[dtype: DType, input_fn: fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width], output_fn: fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None, /, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](input_shape: IndexList[size], reduce_dim: Int, context: DeviceContextPtr = DeviceContextPtr())` Computes the sum across the input and output shape. This performs the sum computation on the domain specified by `input_shape`, loading the inputs using the `input_fn`. The results are stored using the `output_fn`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input and output. * ​input\_fn (`fn[width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, width]`): The function to load the input. * ​output\_fn (`fn[width: Int, rank: Int](IndexList[rank], SIMD[dtype, width]) capturing -> None`): The function to store the output. * ​single\_thread\_blocking\_override ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, then the operation is run synchronously using a single thread. * ​target (`StringSlice`): The target to run on. **Args:** * ​input\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The input shape. * ​reduce\_dim ([`Int`](/mojo/stdlib/builtin/int/Int)): The axis to perform the sum on. * ​context ([`DeviceContextPtr`](/mojo/stdlib/runtime/asyncrt/DeviceContextPtr)): The pointer to DeviceContext. **Raises:** If the operation fails. 
`sum[dtype: DType, input_fn_1d: fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]](length: Int) -> Scalar[dtype]` Computes the sum of a 1D array using a provided input function. This function performs a reduction (sum) over a 1-dimensional array of the specified length and data type. The input values are provided by the `input_fn_1d` function, which takes an index and returns a SIMD vector of the specified width and data type. The reduction is performed using a single thread for deterministic results. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements to sum. * ​input\_fn\_1d (`fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]`): A function that takes a data type, SIMD width, and index, and returns a SIMD vector of input values. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the 1D array. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The sum of all elements as a scalar of the specified data type. **Raises:** Any exception raised by the input function or reduction process.
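A Python sketch of the generator-driven 1D sum (an illustration of the semantics, not the vectorized implementation; the lambda is a hypothetical stand-in for `input_fn_1d`):

```python
def sum_1d(length, input_fn_1d):
    """Sum the values produced by a generator function over [0, length)."""
    return sum(input_fn_1d(i) for i in range(length))


# Sum of the even numbers 0, 2, 4, 6, 8.
assert sum_1d(5, lambda i: i * 2) == 20
```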
--- ## variance (Reduction)
`variance[dtype: DType](src: Span[Scalar[dtype], origin], mean_value: Scalar[dtype], correction: Int = 1) -> Scalar[dtype]` Given a mean, computes the variance of elements in a buffer. The mean value is used to avoid a second pass over the data: ``` variance(x) = sum((x - E(x))^2) / (size - correction) ``` **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer. * ​mean\_value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The mean value of the buffer. * ​correction ([`Int`](/mojo/stdlib/builtin/int/Int)): Normalize variance by size - correction. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The variance value of the elements in a buffer. **Raises:** If the operation fails. `variance[dtype: DType, input_fn_1d: fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]](length: Int, mean_value: Scalar[dtype], correction: Int = 1) -> Scalar[dtype]` Computes the variance of values generated by a function. Variance is calculated as: $$ \operatorname{variance}(X) = \frac{ \sum_{i=0}^{length-1} (X_i - \operatorname{E}(X))^2}{length - correction} $$ where `E(X)` is the mean of the sample. This version takes the mean value as an argument to avoid a second pass over the data. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements. * ​input\_fn\_1d (`fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]`): A function that generates SIMD values at each index. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements. * ​mean\_value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The pre-computed mean value. * ​correction ([`Int`](/mojo/stdlib/builtin/int/Int)): Normalize variance by size - correction (default: 1 for sample variance). **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The variance value.
**Raises:** If length is less than or equal to correction. `variance[dtype: DType](src: Span[Scalar[dtype], origin], correction: Int = 1) -> Scalar[dtype]` Computes the variance value of the elements in a buffer. ``` variance(x) = sum((x - E(x))^2) / (size - correction) ``` **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The dtype of the input. **Args:** * ​src ([`Span`](/mojo/stdlib/memory/span/Span)): The buffer. * ​correction ([`Int`](/mojo/stdlib/builtin/int/Int)): Normalize variance by size - correction (Default=1). **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The variance value of the elements in a buffer. **Raises:** If the operation fails. `variance[dtype: DType, input_fn_1d: fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]](length: Int, correction: Int = 1) -> Scalar[dtype]` Computes the variance of values generated by a function. This version computes the mean automatically in a first pass. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the elements. * ​input\_fn\_1d (`fn[dtype_: DType, width: Int](idx: Int) capturing -> SIMD[dtype_, width]`): A function that generates SIMD values at each index. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements. * ​correction ([`Int`](/mojo/stdlib/builtin/int/Int)): Normalize variance by size - correction (default: 1 for sample variance). **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The variance value. **Raises:** If length is less than or equal to correction.
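A Python sketch of the two-pass variance with the `correction` normalization (an illustration of the semantics, not the Mojo implementation):

```python
def variance(src, correction=1):
    """Two-pass variance: compute the mean, then the squared deviations."""
    n = len(src)
    if n <= correction:
        # Mirrors the documented raise condition.
        raise ValueError("length must exceed correction")
    mean_value = sum(src) / n
    return sum((x - mean_value) ** 2 for x in src) / (n - correction)


# Population variance (correction=0) of [1, 2, 3, 4]:
assert variance([1.0, 2.0, 3.0, 4.0], correction=0) == 1.25
# Sample variance (default correction=1) divides by n - 1 instead.
assert abs(variance([1.0, 2.0, 3.0, 4.0]) - 5.0 / 3.0) < 1e-12
```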
--- ## b16decode
`b16decode(str: StringSlice[origin]) -> String` Performs base16 decoding on the input string. **Args:** * ​str ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A base16 encoded string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The decoded string.
--- ## b16encode
`b16encode(str: StringSlice[origin]) -> String` Performs base16 encoding on the input string slice. **Args:** * ​str ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The input string slice. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): Base16 encoding of the input string.
--- ## b64decode
`b64decode[*, validate: Bool = False](str: StringSlice[origin]) -> String` Performs base64 decoding on the input string. **Parameters:** * ​validate ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true, the function will validate the input string. **Args:** * ​str ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A base64 encoded string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The decoded string. **Raises:** If the operation fails.
--- ## b64encode
`b64encode(input_bytes: Span[Byte, origin], mut result: String)` Performs base64 encoding on the input string. Notes: This method reserves the necessary capacity. `result` can be a 0 capacity string. **Args:** * ​input\_bytes ([`Span`](/mojo/stdlib/memory/span/Span)): The input string buffer. * ​result ([`String`](/mojo/stdlib/collections/string/string/String)): The string in which to store the values. `b64encode(input_string: StringSlice[origin]) -> String` Performs base64 encoding on the input string. **Args:** * ​input\_string ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The input string buffer. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The ASCII base64 encoded string. `b64encode(input_bytes: Span[Byte, origin]) -> String` Performs base64 encoding on the input string. **Args:** * ​input\_bytes ([`Span`](/mojo/stdlib/memory/span/Span)): The input string buffer. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The ASCII base64 encoded string.
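The encoding follows standard base64 semantics, so Python's `base64` module can serve as a cross-check of the expected output (a Python illustration, not the Mojo API):

```python
import base64

data = b"Mojo"
encoded = base64.b64encode(data).decode("ascii")
assert encoded == "TW9qbw=="          # ASCII base64 of "Mojo"
assert base64.b64decode(encoded) == data  # round-trips back to the input
```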
--- ## base64
Provides functions for base64 encoding strings. You can import these APIs from the `base64` package. For example: ```mojo from base64 import b64encode ``` ## Functions * [​`b16decode`](/mojo/stdlib/base64/base64/b16decode): Performs base16 decoding on the input string. * [​`b16encode`](/mojo/stdlib/base64/base64/b16encode): Performs base16 encoding on the input string slice. * [​`b64decode`](/mojo/stdlib/base64/base64/b64decode): Performs base64 decoding on the input string. * [​`b64encode`](/mojo/stdlib/base64/base64/b64encode): Performs base64 encoding on the input string.
--- ## base64 (Base64)
Implements the base64 package. ## Modules * [​`base64`](/mojo/stdlib/base64/base64/): Provides functions for base64 encoding strings.
--- ## Bench
`struct Bench` Constructs a Benchmark object, used for running multiple benchmarks and comparing the results. Example:

```mojo
from benchmark import (
    Bench,
    BenchConfig,
    Bencher,
    BenchId,
    ThroughputMeasure,
    BenchMetric,
    Format,
)
from utils import IndexList
from gpu.host import DeviceContext
from pathlib import Path


fn example_kernel():
    print("example_kernel")


var shape = IndexList[2](1024, 1024)
var bench = Bench(BenchConfig(max_iters=100))


@parameter
@always_inline
fn example(mut b: Bencher, shape: IndexList[2]) capturing raises:
    @parameter
    @always_inline
    fn kernel_launch(ctx: DeviceContext) raises:
        ctx.enqueue_function_checked[example_kernel, example_kernel](
            grid_dim=shape[0], block_dim=shape[1]
        )

    var bench_ctx = DeviceContext()
    b.iter_custom[kernel_launch](bench_ctx)


bench.bench_with_input[IndexList[2], example](
    BenchId("top_k_custom", "gpu"),
    shape,
    [
        ThroughputMeasure(
            BenchMetric.elements, shape.flattened_length()
        ),
        ThroughputMeasure(
            BenchMetric.flops, shape.flattened_length() * 3  # number of ops
        ),
    ]
)

# Add more benchmarks like above to compare results

# Pretty print in table format
print(bench)

# Dump report to csv file
bench.config.out_file = Path("out.csv")
bench.dump_report()

# Print in tabular csv format
bench.config.format = Format.tabular
print(bench)
```

You can pass arguments when running a program that makes use of `Bench`:

```sh
mojo benchmark.mojo -o out.csv -r 10
```

This will repeat the benchmarks 10 times and write the output to `out.csv` in csv format. ## Fields * ​config (`BenchConfig`): Benchmark configuration object controlling the length and frequency of benchmarks. * ​mode (`Mode`): Benchmark mode object representing benchmark or test mode. * ​info\_vec (`List[BenchmarkInfo]`): A list containing the benchmark info.
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ## Methods ### `__init__` `__init__(out self, config: Optional[BenchConfig] = None, mode: Mode = Mode.Benchmark)` Constructs a Benchmark object based on specific configuration and mode. **Args:** * ​config ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Benchmark configuration object to control length and frequency of benchmarks. * ​mode ([`Mode`](/mojo/stdlib/benchmark/bencher/Mode)): Benchmark mode object representing benchmark or test mode. **Raises:** If the operation fails. ### `check_mpirun` `check_mpirun(mut self) -> Int` Checks the environment to determine whether the benchmark was launched via mpirun. If so, uses pe\_rank=OMPI\_COMM\_WORLD\_RANK as a suffix for the output file. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer representing the PE rank (default=-1). **Raises:** If the operation fails. ### `append_output_suffix` `append_output_suffix(mut self, suffix: String)` Appends a suffix string to the output file name. **Args:** * ​suffix ([`String`](/mojo/stdlib/collections/string/string/String)): Suffix string to append to the output file name. ### `bench_with_input` `bench_with_input[T: AnyType, bench_fn: fn(mut Bencher, T) raises capturing -> None](mut self, bench_id: BenchId, input: T, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmarks an input function with input args of type AnyType. **Parameters:** * ​T ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Benchmark function input type. * ​bench\_fn (`fn(mut Bencher, T) raises capturing -> None`): The function to be benchmarked.
**Args:** * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​input (`T`): Represents the target function's input arguments. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of `ThroughputMeasure` values. **Raises:** If the operation fails. `bench_with_input[T: AnyTrivialRegType, bench_fn: fn(mut Bencher, T) raises capturing -> None](mut self, bench_id: BenchId, input: T, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmarks an input function with input args of type AnyTrivialRegType. **Parameters:** * ​T ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Benchmark function input type. * ​bench\_fn (`fn(mut Bencher, T) raises capturing -> None`): The function to be benchmarked. **Args:** * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​input (`T`): Represents the target function's input arguments. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of `ThroughputMeasure` values. **Raises:** If the operation fails. ### `bench_function` `bench_function[bench_fn: fn() raises capturing -> None](mut self, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure](), fixed_iterations: Optional[Int] = None)` Benchmarks or tests an input function. **Parameters:** * ​bench\_fn (`fn() raises capturing -> None`): The function to be benchmarked. **Args:** * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of `ThroughputMeasure` values. * ​fixed\_iterations ([`Optional`](/mojo/stdlib/collections/optional/Optional)): If set, runs exactly this number of iterations. **Raises:** If the operation fails.
`bench_function[bench_fn: fn() capturing -> None](mut self, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure](), fixed_iterations: Optional[Int] = None)` Benchmarks or tests an input function. **Parameters:** * ​bench\_fn (`fn() capturing -> None`): The function to be benchmarked. **Args:** * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of `ThroughputMeasure` values. * ​fixed\_iterations ([`Optional`](/mojo/stdlib/collections/optional/Optional)): If set, runs exactly this number of iterations. **Raises:** If the operation fails. `bench_function[bench_fn: fn(mut Bencher) raises capturing -> None](mut self, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure](), fixed_iterations: Optional[Int] = None)` Benchmarks or tests an input function. **Parameters:** * ​bench\_fn (`fn(mut Bencher) raises capturing -> None`): The function to be benchmarked. **Args:** * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of `ThroughputMeasure` values. * ​fixed\_iterations ([`Optional`](/mojo/stdlib/collections/optional/Optional)): If set, runs exactly this number of iterations. **Raises:** If the operation fails. ### `dump_report` `dump_report(mut self)` Prints out the report from a Benchmark execution. If `Bench.config.out_file` is set, it will also write the output in the format set in `out_file_format` to the file defined in `out_file`. **Raises:** If the operation fails. ### `pad` `pad[pad_str: StringSlice[StaticConstantOrigin] = " "](self, width: Int, string: String) -> String` Pads a string to a given width.
**Parameters:** * ​pad\_str ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The length-1 string to use for the padding. **Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width to pad the string to. * ​string ([`String`](/mojo/stdlib/collections/string/string/String)): The string to pad. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string padded to the given width. ### `__str__` `__str__(self) -> String` Returns a string representation of the benchmark results. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representing the benchmark results. ### `write_to` `write_to(self, mut writer: T)` Writes the benchmark results to a writer. **Args:** * ​writer (`T`): The writer to write to.
--- ## BenchConfig
`struct BenchConfig` Defines a benchmark configuration struct to control execution times and frequency. ## Fields * ​out\_file (`Optional[Path]`): Output file to write results to. * ​min\_runtime\_secs (`Float64`): Lower bound on benchmarking time in secs. * ​max\_runtime\_secs (`Float64`): Upper bound on benchmarking time in secs. * ​num\_warmup\_iters (`Int`): Number of warmup iterations. * ​max\_batch\_size (`Int`): The maximum number of iterations to perform per time measurement. * ​max\_iters (`Int`): Max number of iterations to run. * ​num\_repetitions (`Int`): Number of times the benchmark has to be repeated. * ​flush\_denormals (`Bool`): Whether to flush denormal values to zero. * ​show\_progress (`Bool`): If True, print progress of each benchmark. * ​format (`Format`): The format to print results. (default: "table"). * ​out\_file\_format (`Format`): The format to write out the file with `dump_report` (default: "csv"). * ​verbose\_timing (`Bool`): Whether to print verbose timing results. * ​verbose\_metric\_names (`Bool`): If True print the metric name and unit, else print the unit only. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `VERBOSE_TIMING_LABELS` `comptime VERBOSE_TIMING_LABELS = List[String]("min (ms)", "mean (ms)", "max (ms)", "duration (ms)", Tuple[]())` Labels to print verbose timing results.
## Methods ### `__init__` `__init__(out self, out_file: Optional[Path] = None, min_runtime_secs: Float64 = 0, max_runtime_secs: Float64 = 1, num_warmup_iters: Int = 10, max_batch_size: Int = 0, max_iters: Int = 1000, num_repetitions: Int = 1, flush_denormals: Bool = True)` Constructs and initializes a benchmark configuration object with default and provided values. **Args:** * ​out\_file ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Output file to write results to. * ​min\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Lower bound on benchmarking time in secs (default `0.0`). * ​max\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Upper bound on benchmarking time in secs (default `1.0`). * ​num\_warmup\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of warmup iterations (default `10`). * ​max\_batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of iterations to perform per time measurement. * ​max\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Max number of iterations to run (default `1_000`). * ​num\_repetitions ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of times the benchmark has to be repeated. * ​flush\_denormals ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to flush denormal values to zero. **Raises:** If the operation fails.
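For instance, a configuration that bounds each benchmark by both time and iteration count (a sketch; argument names follow the constructor signature above):

```mojo
from benchmark import Bench, BenchConfig

def main():
    # Cap each benchmark at 500 iterations or 2 seconds, whichever
    # comes first, with 5 warmup iterations.
    var config = BenchConfig(
        max_runtime_secs=2.0,
        num_warmup_iters=5,
        max_iters=500,
    )
    var bench = Bench(config)
```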
--- ## BenchId
`struct BenchId` Defines a benchmark Id struct to identify and represent a particular benchmark execution. ## Fields * ​func\_name (`String`): The target function name. * ​input\_id (`Optional[String]`): The target function input id phrase. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ## Methods ### `__init__` `__init__(out self, func_name: String, input_id: String)` Constructs a Benchmark Id object from input function name and Id phrase. **Args:** * ​func\_name ([`String`](/mojo/stdlib/collections/string/string/String)): The target function name. * ​input\_id ([`String`](/mojo/stdlib/collections/string/string/String)): The target function input id phrase. `__init__(out self, func_name: String)` Constructs a Benchmark Id object from input function name. **Args:** * ​func\_name ([`String`](/mojo/stdlib/collections/string/string/String)): The target function name. `__init__(out self, func_name: StringLiteral[value])` Constructs a Benchmark Id object from input function name. **Args:** * ​func\_name ([`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral)): The target function name.
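The two constructor forms above can be used as follows (a sketch; the names are illustrative):

```mojo
from benchmark import BenchId

def main():
    # Name only: identifies the function being benchmarked.
    var simple = BenchId("matmul")

    # Name plus input id: distinguishes runs of the same function
    # on different inputs (e.g. problem sizes or devices).
    var detailed = BenchId("matmul", "1024x1024")
```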
--- ## BenchMetric
`struct BenchMetric` Defines a benchmark throughput metric. ## Fields * ​code (`Int`): Op-code of the Metric. * ​name (`String`): Metric's name. * ​unit (`String`): Metric's throughput rate unit (count/second). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `bytes` `comptime bytes = BenchMetric(1, "DataMovement", "GB/s")` Metric for measuring data movement in bytes per second. ### `DEFAULTS` `comptime DEFAULTS = List[BenchMetric](BenchMetric.elements, BenchMetric.bytes, BenchMetric.flops, Tuple[]())` Default set of benchmark metrics. ### `elements` `comptime elements = BenchMetric(0, "throughput", "GElems/s")` Metric for measuring throughput in elements per second. ### `flops` `comptime flops = BenchMetric(2, "Arithmetic", "GFLOPS/s")` Metric for measuring floating point operations per second. ### `theoretical_flops` `comptime theoretical_flops = BenchMetric(3, "TheoreticalArithmetic", "GFLOPS/s")` Metric for measuring theoretical floating point operations per second. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` Compares two metrics for equality. **Args:** * ​other (`Self`): The metric to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the two metrics are equal. ### `__ne__` `__ne__(self, other: Self) -> Bool` Compares two metrics for inequality. **Args:** * ​other (`Self`): The metric to compare. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the two metrics are NOT equal. ### `__str__` `__str__(self) -> String` Gets a string representation of this metric. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation. ### `write_to` `write_to(self, mut writer: T)` Formats this BenchMetric to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `check_name` `check_name(self, alt_name: String) -> Bool` Checks whether a string contains the metric's name. **Args:** * ​alt\_name ([`String`](/mojo/stdlib/collections/string/string/String)): Alternative name of a metric. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `alt_name` is a valid alternative for the metric's name. ### `get_metric_from_list` `static get_metric_from_list(name: String, metric_list: List[BenchMetric]) -> Self` Gets a metric from a given list using only the metric's name. **Args:** * ​name ([`String`](/mojo/stdlib/collections/string/string/String)): Metric's name. * ​metric\_list ([`List`](/mojo/stdlib/collections/list/List)): List of metrics to search. **Returns:** `Self`: The selected metric. **Raises:** If the operation fails.
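A custom metric can be defined alongside the built-in ones; a sketch (the op-code `4` and the metric name are illustrative):

```mojo
from benchmark import BenchMetric

def main():
    # Define a custom metric with a fresh op-code, name, and rate unit,
    # mirroring the built-in definitions such as
    # BenchMetric(1, "DataMovement", "GB/s").
    var custom = BenchMetric(4, "memory_bandwidth", "GB/s")

    # Extend the default metric set so the custom metric can be
    # looked up by name.
    var metrics = BenchMetric.DEFAULTS
    metrics.append(custom)
    var found = BenchMetric.get_metric_from_list(
        "memory_bandwidth", metrics
    )
    print(found)
```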
--- ## Bencher
`@register_passable` `struct Bencher` Defines a Bencher struct which facilitates the timing of a target function. ## Fields * ​num\_iters (`Int`): Number of iterations to run the target function. * ​elapsed (`Int`): The total time elapsed when running the target function. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(num_iters: Int) -> Self` Constructs a Bencher object to run and time a function. **Args:** * ​num\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of times to run the target function. ### `iter` `iter[iter_fn: fn() capturing -> None](mut self)` Measures the total elapsed time of running a target function the configured number of times. **Parameters:** * ​iter\_fn (`fn() capturing -> None`): The target function to benchmark. `iter[iter_fn: fn() raises capturing -> None](mut self)` Measures the total elapsed time of running a target function the configured number of times. **Parameters:** * ​iter\_fn (`fn() raises capturing -> None`): The target function to benchmark. **Raises:** If the operation fails. ### `iter_preproc` `iter_preproc[iter_fn: fn() capturing -> None, preproc_fn: fn() capturing -> None](mut self)` Measures the total elapsed time of running a target function the configured number of times, running a preprocessing function before each timed run. **Parameters:** * ​iter\_fn (`fn() capturing -> None`): The target function to benchmark. * ​preproc\_fn (`fn() capturing -> None`): The function to preprocess the target function. ### `iter_custom` `iter_custom[iter_fn: fn(Int) raises capturing -> Int](mut self)` Times a target function with a custom number of iterations. **Parameters:** * ​iter\_fn (`fn(Int) raises capturing -> Int`): The target function to benchmark.
`iter_custom[kernel_launch_fn: fn(DeviceContext) raises capturing -> None](mut self, ctx: DeviceContext)` Times a target GPU function with a custom number of iterations using the provided DeviceContext. **Parameters:** * ​kernel\_launch\_fn (`fn(DeviceContext) raises capturing -> None`): The target GPU kernel launch function to benchmark. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The GPU DeviceContext for launching the kernel. `iter_custom[kernel_launch_fn: fn(DeviceContext, Int) raises capturing -> None](mut self, ctx: DeviceContext)` Times a target GPU function with a custom number of iterations using the provided DeviceContext. **Parameters:** * ​kernel\_launch\_fn (`fn(DeviceContext, Int) raises capturing -> None`): The target GPU kernel launch function to benchmark. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The GPU DeviceContext for launching the kernel. ### `iter_custom_multicontext` `iter_custom_multicontext[kernel_launch_fn: fn() raises capturing -> None](mut self, ctxs: List[DeviceContext])` Times a target GPU function with a custom number of iterations using the provided list of device contexts. **Parameters:** * ​kernel\_launch\_fn (`fn() raises capturing -> None`): The target GPU kernel launch function to benchmark. **Args:** * ​ctxs ([`List`](/mojo/stdlib/collections/list/List)): The list of GPU DeviceContexts for launching the kernel.
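For CPU code, the `iter` path pairs with `Bench.bench_function`; a minimal sketch (the `keep` helper, assumed importable from the `benchmark` package, prevents the compiler from optimizing the work away):

```mojo
from benchmark import Bench, Bencher, BenchId, keep

def main():
    var bench = Bench()

    @parameter
    fn bench_sum(mut b: Bencher) raises:
        @parameter
        fn run():
            var total = 0
            for i in range(1_000):
                total += i
            # Keep the result alive so the loop is not optimized away.
            keep(total)

        b.iter[run]()

    bench.bench_function[bench_sum](BenchId("sum_1k"))
    print(bench)
```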
--- ## BenchmarkInfo
`struct BenchmarkInfo` Defines a Benchmark Info struct to record execution statistics. ## Fields * ​name (`String`): The name of the benchmark. * ​result (`Report`): The output report after executing a benchmark. * ​measures (`List[ThroughputMeasure]`): Optional arg used to represent a list of `ThroughputMeasure` values. * ​verbose\_timing (`Bool`): Whether to print verbose timing results. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self, name: String, var result: Report, var measures: List[ThroughputMeasure] = List[ThroughputMeasure](), verbose_timing: Bool = False)` Constructs a `BenchmarkInfo` object to return benchmark report and statistics. **Args:** * ​name ([`String`](/mojo/stdlib/collections/string/string/String)): The name of the benchmark. * ​result ([`Report`](/mojo/stdlib/benchmark/benchmark/Report)): The output report after executing a benchmark. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of `ThroughputMeasure` values. * ​verbose\_timing ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to print verbose timing results.
--- ## Format
`struct Format` Defines a format for the benchmark output when printing or writing to a file. ## Fields * ​value (`StaticString`): The format to print results. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `csv` `comptime csv = Format("csv")` Comma separated values with no alignment. ### `table` `comptime table = Format("table")` Table format with dynamically aligned columns. ### `tabular` `comptime tabular = Format("tabular")` Comma separated values with dynamically aligned columns. ## Methods ### `__init__` `__init__(out self, value: StringSlice[origin])` Constructs a Format object from a string. **Args:** * ​value ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The format to print results. ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if two Format objects are equal. **Args:** * ​other (`Self`): The `Format` to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the two `Format` objects are equal, false otherwise. ### `__str__` `__str__(self) -> String` Returns the string representation of the format. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the format. ### `write_to` `write_to(self, mut writer: T)` Writes the format to a writer. **Args:** * ​writer (`T`): The writer to write the `Format` to.
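The format members above are typically set on `Bench.config`; a sketch (file name is illustrative):

```mojo
from benchmark import Bench, Format
from pathlib import Path

def main():
    var bench = Bench()
    # ... register benchmarks here ...

    # Print aligned-column CSV to stdout.
    bench.config.format = Format.tabular
    print(bench)

    # Write a plain CSV report to a file.
    bench.config.out_file = Path("results.csv")
    bench.config.out_file_format = Format.csv
    bench.dump_report()
```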
--- ## Mode
`struct Mode` Defines a Benchmark Mode to distinguish between test runs and actual benchmarks. ## Fields * ​value (`Int`): Represents the mode type. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Benchmark` `comptime Benchmark = Mode(0)` Mode for running actual benchmarks. ### `Test` `comptime Test = Mode(1)` Mode for running tests. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks whether this mode is the same as another mode. **Args:** * ​other (`Self`): The mode to be compared against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both values are the same mode.
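The mode is normally passed to the `Bench` constructor; a sketch:

```mojo
from benchmark import Bench, Mode

def main():
    # Test mode distinguishes quick verification runs from full,
    # repeatedly-timed benchmark runs.
    var bench = Bench(mode=Mode.Test)

    if bench.mode == Mode.Test:
        print("running in test mode")
```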
--- ## ThroughputMeasure
`struct ThroughputMeasure` Records a throughput metric, pairing a `BenchMetric` with a measured value. ## Fields * ​metric (`BenchMetric`): Type of throughput metric. * ​value (`Int`): Measured count of throughput metric. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self, name: String, value: Int, reference: List[BenchMetric] = BenchMetric.DEFAULTS)` Creates a `ThroughputMeasure` based on the metric's name. Example: For the default bench metrics `BenchMetric.DEFAULTS` the following are equivalent: \- `ThroughputMeasure(BenchMetric.fmas, 1024)` \- `ThroughputMeasure("fmas", 1024)` \- `ThroughputMeasure("fmas", 1024, BenchMetric.DEFAULTS)` **Args:** * ​name ([`String`](/mojo/stdlib/collections/string/string/String)): The name of BenchMetric in its corresponding reference. * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): The measured value to assign to this metric. * ​reference ([`List`](/mojo/stdlib/collections/list/List)): List of BenchMetrics that contains this metric. **Raises:** If the operation fails. ### `__str__` `__str__(self) -> String` Gets a string representation of this `ThroughputMeasure`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation. ### `write_to` `write_to(self, mut writer: T)` Formats this ThroughputMeasure to the provided Writer. **Args:** * ​writer (`T`): The object to write to.
### `compute` `compute(self, elapsed_sec: Float64) -> Float64` Computes the throughput rate for this metric per second. **Args:** * ​elapsed\_sec ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Elapsed time measured in seconds. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The throughput value as a 64-bit float.
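A short sketch of constructing a measure and computing its rate (the element count and elapsed time are illustrative):

```mojo
from benchmark import BenchMetric, ThroughputMeasure

def main():
    # 1_048_576 elements processed per iteration.
    var measure = ThroughputMeasure(BenchMetric.elements, 1_048_576)

    # Throughput rate for a run that took 2 milliseconds.
    var rate = measure.compute(0.002)
    print(rate)
```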
--- ## bencher (Bencher)
## Structs * [​`Bench`](/mojo/stdlib/benchmark/bencher/Bench): Constructs a Benchmark object, used for running multiple benchmarks and comparing the results. * [​`BenchConfig`](/mojo/stdlib/benchmark/bencher/BenchConfig): Defines a benchmark configuration struct to control execution times and frequency. * [​`Bencher`](/mojo/stdlib/benchmark/bencher/Bencher): Defines a Bencher struct which facilitates the timing of a target function. * [​`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId): Defines a benchmark Id struct to identify and represent a particular benchmark execution. * [​`BenchmarkInfo`](/mojo/stdlib/benchmark/bencher/BenchmarkInfo): Defines a Benchmark Info struct to record execution statistics. * [​`BenchMetric`](/mojo/stdlib/benchmark/bencher/BenchMetric): Defines a benchmark throughput metric. * [​`Format`](/mojo/stdlib/benchmark/bencher/Format): Defines a format for the benchmark output when printing or writing to a file. * [​`Mode`](/mojo/stdlib/benchmark/bencher/Mode): Defines a Benchmark Mode to distinguish between test runs and actual benchmarks. * [​`ThroughputMeasure`](/mojo/stdlib/benchmark/bencher/ThroughputMeasure): Records a throughput metric, pairing a BenchMetric with a measured value.
--- ## Batch
`@register_passable(trivial)` `struct Batch` A batch of benchmarks. The `benchmark.run()` function works out how many iterations to run in each batch based on how long the previous iterations took. ## Fields * ​duration (`Int`): Total duration of batch stored as nanoseconds. * ​iterations (`Int`): Total iterations in the batch. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `mean` `mean(self, unit: String = Unit.s) -> Float64` Returns the average duration of the batch. **Args:** * ​unit ([`String`](/mojo/stdlib/collections/string/string/String)): The time unit to display, for example: ns, ms, s (default `s`). **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The average duration of the batch.
--- ## Report
`struct Report` Contains the average execution time, iterations, min and max of each batch. ## Fields * ​warmup\_duration (`Int`): The total duration it took to warm up. * ​runs (`List[Batch]`): A `List` of benchmark runs. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` Default initializer for the Report. Sets all values to 0. ### `iters` `iters(self) -> Int` The total benchmark iterations. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total benchmark iterations. ### `duration` `duration(self, unit: String = Unit.s) -> Float64` The total duration it took to run all benchmarks. **Args:** * ​unit ([`String`](/mojo/stdlib/collections/string/string/String)): The time unit to display, for example: ns, ms, s (default `s`). **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The total duration it took to run all benchmarks. ### `mean` `mean(self, unit: String = Unit.s) -> Float64` The average duration of all benchmark runs. **Args:** * ​unit ([`String`](/mojo/stdlib/collections/string/string/String)): The time unit to display, for example: ns, ms, s (default `s`). **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The average duration of all benchmark runs. ### `min` `min(self, unit: String = Unit.s) -> Float64` Returns the duration of the batch that was the fastest to run. **Args:** * ​unit ([`String`](/mojo/stdlib/collections/string/string/String)): The time unit to display, for example: ns, ms, s (default `s`).
**Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The fastest duration out of all batches. ### `max` `max(self, unit: String = Unit.s) -> Float64` Returns the duration of the batch that was the slowest to run. **Args:** * ​unit ([`String`](/mojo/stdlib/collections/string/string/String)): The time unit to display, for example: ns, ms, s (default `s`). **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The slowest duration out of all batches. ### `print` `print(self, unit: String = Unit.s)` Prints out the shortened version of the report. **Args:** * ​unit ([`String`](/mojo/stdlib/collections/string/string/String)): The time unit to display, for example: ns, ms, s (default `s`). ### `print_full` `print_full(self, unit: String = Unit.s)` Prints out the full version of the report with each batch of benchmark runs. **Args:** * ​unit ([`String`](/mojo/stdlib/collections/string/string/String)): The time unit to display, for example: ns, ms, s (default `s`).
--- ## Unit
`struct Unit` Time Unit used by Benchmark Report. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `ms` `comptime ms = "ms"` Milliseconds. ### `ns` `comptime ns = "ns"` Nanoseconds. ### `s` `comptime s = "s"` Seconds.
--- ## benchmark
Implements the benchmark module for runtime benchmarking.

You can import these APIs from the `benchmark` package. For example:

```mojo
import benchmark
from time import sleep
```

You can pass any `fn` as a parameter into `benchmark.run[...]()`, and it will return a `Report` from which you can get the mean, duration, max, and more:

```mojo
fn sleeper():
    sleep(.01)

var report = benchmark.run[sleeper]()
print(report.mean())
```

```output
0.012256487394957985
```

You can print a full report:

```mojo
report.print()
```

```output
---------------------
Benchmark Report (s)
---------------------
Mean: 0.012265747899159664
Total: 1.459624
Iters: 119
Warmup Total: 0.025020000000000001
Fastest Mean: 0.0121578
Slowest Mean: 0.012321428571428572
```

Or all the batch runs:

```mojo
report.print_full()
```

```output
---------------------
Benchmark Report (s)
---------------------
Mean: 0.012368649122807017
Total: 1.410026
Iters: 114
Warmup Total: 0.023341000000000001
Fastest Mean: 0.012295586956521738
Slowest Mean: 0.012508099999999999

Batch: 1
Iterations: 20
Mean: 0.012508099999999999
Duration: 0.250162

Batch: 2
Iterations: 46
Mean: 0.012295586956521738
Duration: 0.56559700000000002

Batch: 3
Iterations: 48
Mean: 0.012380562499999999
Duration: 0.59426699999999999
```

If you want to use a different time unit, you can bring in the `Unit` and pass it in as an argument:

```mojo
from benchmark import Unit
report.print(Unit.ms)
```

```output
---------------------
Benchmark Report (ms)
---------------------
Mean: 0.012312411764705882
Total: 1.465177
Iters: 119
Warmup Total: 0.025010999999999999
Fastest Mean: 0.012015649999999999
Slowest Mean: 0.012421204081632654
```

The units are just aliases for string constants, so you can, for example:

```mojo
print(report.mean("ms"))
```

```output
12.199145299145298
```

`benchmark.run` takes several optional arguments to change its behaviour. To set the warmup iterations to 5:

```mojo
r = benchmark.run[sleeper](5)
```

```output
0.012004808080808081
```

To set 1 warmup iteration,
2 max iterations, a minimum total time of 3 seconds, and a maximum total time of 4 seconds:

```mojo
r = benchmark.run[sleeper](1, 2, 3, 4)
```

Note that benchmarking continues until `min_runtime_secs` has elapsed and either `max_runtime_secs` OR `max_iters` is achieved.

## Structs

* [​`Batch`](/mojo/stdlib/benchmark/benchmark/Batch): A batch of benchmarks. The `benchmark.run()` function works out how many iterations to run in each batch based on how long the previous iterations took.
* [​`Report`](/mojo/stdlib/benchmark/benchmark/Report): Contains the average execution time, iterations, min and max of each batch.
* [​`Unit`](/mojo/stdlib/benchmark/benchmark/Unit): Time Unit used by Benchmark Report.

## Functions

* [​`run`](/mojo/stdlib/benchmark/benchmark/run): Benchmarks the function passed in as a parameter.
--- ## run
`run[func: fn() raises -> None](num_warmup_iters: Int = 1, max_iters: Int = 1000000000, min_runtime_secs: Float64 = 0.10000000000000001, max_runtime_secs: Float64 = 60, max_batch_size: Int = 0) -> Report` Benchmarks the function passed in as a parameter. Benchmarking continues until 'min\_runtime\_secs' has elapsed and either `max_runtime_secs` OR `max_iters` is achieved. **Parameters:** * ​func (`fn() raises -> None`): The function to benchmark. **Args:** * ​num\_warmup\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of warmup iterations. * ​max\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Max number of iterations to run (default `1_000_000_000`). * ​min\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Lower bound on benchmarking time in secs (default `0.1`). * ​max\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Upper bound on benchmarking time in secs (default `60`). * ​max\_batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of iterations to perform per time measurement. **Returns:** [`Report`](/mojo/stdlib/benchmark/benchmark/Report): Average execution time of func in ns. **Raises:** If the operation fails. `run[func: fn() -> None](num_warmup_iters: Int = 1, max_iters: Int = 1000000000, min_runtime_secs: Float64 = 0.10000000000000001, max_runtime_secs: Float64 = 60, max_batch_size: Int = 0) -> Report` Benchmarks the function passed in as a parameter. Benchmarking continues until 'min\_runtime\_secs' has elapsed and either `max_runtime_secs` OR `max_iters` is achieved. **Parameters:** * ​func (`fn() -> None`): The function to benchmark. **Args:** * ​num\_warmup\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of warmup iterations. * ​max\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Max number of iterations to run (default `1_000_000_000`). * ​min\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Lower bound on benchmarking time in secs (default `0.1`). 
* ​max\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Upper bound on benchmarking time in secs (default `60`). * ​max\_batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of iterations to perform per time measurement. **Returns:** [`Report`](/mojo/stdlib/benchmark/benchmark/Report): Average execution time of func in ns. **Raises:** If the operation fails. `run[func: fn() raises capturing -> None](num_warmup_iters: Int = 1, max_iters: Int = 1000000000, min_runtime_secs: Float64 = 0.10000000000000001, max_runtime_secs: Float64 = 60, max_batch_size: Int = 0) -> Report` Benchmarks the function passed in as a parameter. Benchmarking continues until 'min\_runtime\_secs' has elapsed and either `max_runtime_secs` OR `max_iters` is achieved. **Parameters:** * ​func (`fn() raises capturing -> None`): The function to benchmark. **Args:** * ​num\_warmup\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of warmup iterations. * ​max\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Max number of iterations to run (default `1_000_000_000`). * ​min\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Lower bound on benchmarking time in secs (default `0.1`). * ​max\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Upper bound on benchmarking time in secs (default `60`). * ​max\_batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of iterations to perform per time measurement. **Returns:** [`Report`](/mojo/stdlib/benchmark/benchmark/Report): Average execution time of func in ns. **Raises:** If the operation fails. `run[func: fn() capturing -> None](num_warmup_iters: Int = 1, max_iters: Int = 1000000000, min_runtime_secs: Float64 = 0.10000000000000001, max_runtime_secs: Float64 = 60, max_batch_size: Int = 0) -> Report` Benchmarks the function passed in as a parameter. Benchmarking continues until 'min\_runtime\_secs' has elapsed and either `max_runtime_secs` OR `max_iters` is achieved. 
**Parameters:** * ​func (`fn() capturing -> None`): The function to benchmark. **Args:** * ​num\_warmup\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of warmup iterations. * ​max\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): Max number of iterations to run (default `1_000_000_000`). * ​min\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Lower bound on benchmarking time in secs (default `0.1`). * ​max\_runtime\_secs ([`Float64`](/mojo/stdlib/builtin/simd/#float64)): Upper bound on benchmarking time in secs (default `60`). * ​max\_batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of iterations to perform per time measurement. **Returns:** [`Report`](/mojo/stdlib/benchmark/benchmark/Report): Average execution time of func in ns. **Raises:** If the operation fails.
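The `capturing` overloads let you benchmark a nested closure that reads variables from the enclosing scope. A minimal sketch (assuming the usual `@parameter` closure pattern; the measured numbers are machine-dependent):

```mojo
import benchmark
from time import sleep

fn main() raises:
    var delay = 0.001

    @parameter
    fn sleeper():
        # Captures `delay` from the enclosing function, so this matches
        # the `fn() capturing -> None` overload of `run`.
        sleep(delay)

    var report = benchmark.run[sleeper]()
    print(report.mean("ms"))
```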
--- ## compiler
## Functions * [​`keep`](/mojo/stdlib/benchmark/compiler/keep): Provides a hint to the compiler to not optimize the variable use away.
--- ## keep
`keep[Type: AnyType, origin: ImmutOrigin, //](ref [origin] val: Type)`

Provides a hint to the compiler to not optimize the variable use away.

This is useful in benchmarking to prevent the compiler from deleting the code to be benchmarked because the variable is not used in a side-effecting manner.

**Parameters:**

* ​Type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The type of the input.
* ​origin ([`ImmutOrigin`](/mojo/stdlib/builtin/type_aliases/#immutorigin)): The origin of the input.

**Args:**

* ​val (`Type`): The value to not optimize away.
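For example, without `keep` the compiler may observe that the result of the benchmarked loop is never used and delete the computation entirely. A minimal sketch:

```mojo
import benchmark
from benchmark import keep

fn main() raises:
    @parameter
    fn bench():
        var total = 0
        for i in range(1000):
            total += i
        # Without this hint, the loop above could be optimized away
        # since `total` is otherwise unused.
        keep(total)

    var report = benchmark.run[bench]()
    print(report.mean("ns"))
```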
--- ## benchmark (Benchmark)
Implements the benchmark package for runtime benchmarking.

You can import these APIs from the `benchmark` package. For example:

```mojo
import benchmark
from time import sleep
```

You can pass any `fn` as a parameter into `benchmark.run[...]()`, and it will return a `Report` from which you can get the mean, duration, max, and more:

```mojo
fn sleeper():
    sleep(.01)

var report = benchmark.run[sleeper]()
print(report.mean())
```

```output
0.012256487394957985
```

You can print a full report:

```mojo
report.print()
```

```output
---------------------
Benchmark Report (s)
---------------------
Mean: 0.012265747899159664
Total: 1.459624
Iters: 119
Warmup Mean: 0.01251
Warmup Total: 0.025020000000000001
Warmup Iters: 2
Fastest Mean: 0.0121578
Slowest Mean: 0.012321428571428572
```

Or all the batch runs:

```mojo
report.print_full()
```

```output
---------------------
Benchmark Report (s)
---------------------
Mean: 0.012368649122807017
Total: 1.410026
Iters: 114
Warmup Mean: 0.0116705
Warmup Total: 0.023341000000000001
Warmup Iters: 2
Fastest Mean: 0.012295586956521738
Slowest Mean: 0.012508099999999999

Batch: 1
Iterations: 20
Mean: 0.012508099999999999
Duration: 0.250162

Batch: 2
Iterations: 46
Mean: 0.012295586956521738
Duration: 0.56559700000000002

Batch: 3
Iterations: 48
Mean: 0.012380562499999999
Duration: 0.59426699999999999
```

If you want to use a different time unit, you can bring in the `Unit` and pass it in as an argument:

```mojo
from benchmark import Unit
report.print(Unit.ms)
```

```output
---------------------
Benchmark Report (ms)
---------------------
Mean: 0.012312411764705882
Total: 1.465177
Iters: 119
Warmup Mean: 0.012505499999999999
Warmup Total: 0.025010999999999999
Warmup Iters: 2
Fastest Mean: 0.012015649999999999
Slowest Mean: 0.012421204081632654
```

The units are just aliases for string constants, so you can, for example:

```mojo
print(report.mean("ms"))
```

```output
12.199145299145298
```

`benchmark.run` takes several optional arguments to change its behaviour. To set
warmup iterations to 5:

```mojo
r = benchmark.run[sleeper](5)
```

```output
0.012004808080808081
```

To set 1 warmup iteration, 2 max iterations, a minimum total time of 3 seconds, and a maximum total time of 4 seconds:

```mojo
r = benchmark.run[sleeper](1, 2, 3, 4)
```

Note that the minimum total time takes precedence over the maximum number of iterations.

## Modules

* [​`bencher`](/mojo/stdlib/benchmark/bencher/):
* [​`benchmark`](/mojo/stdlib/benchmark/benchmark/): Implements the benchmark module for runtime benchmarking.
* [​`compiler`](/mojo/stdlib/benchmark/compiler/):
* [​`memory`](/mojo/stdlib/benchmark/memory/):
* [​`quick_bench`](/mojo/stdlib/benchmark/quick_bench/):
--- ## clobber_memory
`clobber_memory()` Forces all pending memory writes to be flushed to memory. This ensures that the compiler does not optimize away memory writes if it deems them to be not necessary. In effect, this operation acts as a barrier to memory reads and writes.
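A typical use pairs `clobber_memory` with `keep` inside the benchmark body so that stores to a buffer are not elided. A minimal sketch (assuming both functions are re-exported from the `benchmark` package):

```mojo
import benchmark
from benchmark import clobber_memory, keep

fn main() raises:
    @parameter
    fn bench():
        var buf = List[Int](capacity=256)
        for i in range(256):
            buf.append(i)
        # Force the writes into `buf` to be considered observable
        # before the timing sample ends.
        clobber_memory()
        keep(buf)

    var report = benchmark.run[bench]()
    print(report.mean("ns"))
```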
--- ## memory (Memory)
## Functions * [​`clobber_memory`](/mojo/stdlib/benchmark/memory/clobber_memory): Forces all pending memory writes to be flushed to memory.
--- ## QuickBench
`struct QuickBench` Defines a struct to facilitate benchmarking and avoiding `Bencher` boilerplate. ## Fields * ​m (`Bench`): Bench object to collect the results. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ## Methods ### `__init__` `__init__(out self)` Initializes the `Bench` object. **Raises:** If the operation fails. ### `dump_report` `dump_report(mut self)` Prints out the report from a Benchmark execution collected in Bench object. **Raises:** If the operation fails. ### `run` `run[T_out: AnyTrivialRegType](mut self, func: fn() -> T_out, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with no input arguments and return type `T_out`. **Parameters:** * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn() -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0) -> T_out, x0: T0, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 1 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. 
**Args:** * ​func (`fn(T0) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1) -> T_out, x0: T0, x1: T1, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 2 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2) -> T_out, x0: T0, x1: T1, x2: T2, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 3 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. 
* ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, T3: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2, T3) -> T_out, x0: T0, x1: T1, x2: T2, x3: T3, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 4 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T3 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 4th argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2, T3) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. 
* ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​x3 (`T3`): The 4th argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, T3: AnyTrivialRegType, T4: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2, T3, T4) -> T_out, x0: T0, x1: T1, x2: T2, x3: T3, x4: T4, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 5 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T3 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 4th argument of func. * ​T4 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 5th argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2, T3, T4) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​x3 (`T3`): The 4th argument of func. * ​x4 (`T4`): The 5th argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. 
* ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, T3: AnyTrivialRegType, T4: AnyTrivialRegType, T5: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2, T3, T4, T5) -> T_out, x0: T0, x1: T1, x2: T2, x3: T3, x4: T4, x5: T5, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 6 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T3 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 4th argument of func. * ​T4 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 5th argument of func. * ​T5 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 6th argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2, T3, T4, T5) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​x3 (`T3`): The 4th argument of func. * ​x4 (`T4`): The 5th argument of func. * ​x5 (`T5`): The 6th argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. 
* ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, T3: AnyTrivialRegType, T4: AnyTrivialRegType, T5: AnyTrivialRegType, T6: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2, T3, T4, T5, T6) -> T_out, x0: T0, x1: T1, x2: T2, x3: T3, x4: T4, x5: T5, x6: T6, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 7 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T3 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 4th argument of func. * ​T4 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 5th argument of func. * ​T5 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 6th argument of func. * ​T6 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 7th argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2, T3, T4, T5, T6) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​x3 (`T3`): The 4th argument of func. * ​x4 (`T4`): The 5th argument of func. * ​x5 (`T5`): The 6th argument of func. * ​x6 (`T6`): The 7th argument of func. 
* ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, T3: AnyTrivialRegType, T4: AnyTrivialRegType, T5: AnyTrivialRegType, T6: AnyTrivialRegType, T7: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2, T3, T4, T5, T6, T7) -> T_out, x0: T0, x1: T1, x2: T2, x3: T3, x4: T4, x5: T5, x6: T6, x7: T7, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 8 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T3 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 4th argument of func. * ​T4 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 5th argument of func. * ​T5 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 6th argument of func. * ​T6 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 7th argument of func. * ​T7 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 8th argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2, T3, T4, T5, T6, T7) -> T_out`): The function to be benchmarked (run in benchmark iterations). 
* ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​x3 (`T3`): The 4th argument of func. * ​x4 (`T4`): The 5th argument of func. * ​x5 (`T5`): The 6th argument of func. * ​x6 (`T6`): The 7th argument of func. * ​x7 (`T7`): The 8th argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, T3: AnyTrivialRegType, T4: AnyTrivialRegType, T5: AnyTrivialRegType, T6: AnyTrivialRegType, T7: AnyTrivialRegType, T8: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2, T3, T4, T5, T6, T7, T8) -> T_out, x0: T0, x1: T1, x2: T2, x3: T3, x4: T4, x5: T5, x6: T6, x7: T7, x8: T8, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 9 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. * ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T3 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 4th argument of func. * ​T4 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 5th argument of func. * ​T5 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 6th argument of func. * ​T6 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 7th argument of func. 
* ​T7 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 8th argument of func. * ​T8 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 9th argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2, T3, T4, T5, T6, T7, T8) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​x3 (`T3`): The 4th argument of func. * ​x4 (`T4`): The 5th argument of func. * ​x5 (`T5`): The 6th argument of func. * ​x6 (`T6`): The 7th argument of func. * ​x7 (`T7`): The 8th argument of func. * ​x8 (`T8`): The 9th argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails. `run[T0: AnyTrivialRegType, T1: AnyTrivialRegType, T2: AnyTrivialRegType, T3: AnyTrivialRegType, T4: AnyTrivialRegType, T5: AnyTrivialRegType, T6: AnyTrivialRegType, T7: AnyTrivialRegType, T8: AnyTrivialRegType, T9: AnyTrivialRegType, /, T_out: AnyTrivialRegType](mut self, func: fn(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) -> T_out, x0: T0, x1: T1, x2: T2, x3: T3, x4: T4, x5: T5, x6: T6, x7: T7, x8: T8, x9: T9, *, bench_id: BenchId, measures: List[ThroughputMeasure] = List[ThroughputMeasure]())` Benchmark function `func` with 10 input argument and return type `T_out`. **Parameters:** * ​T0 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 1st argument of func. * ​T1 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 2nd argument of func. 
* ​T2 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 3rd argument of func. * ​T3 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 4th argument of func. * ​T4 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 5th argument of func. * ​T5 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 6th argument of func. * ​T6 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 7th argument of func. * ​T7 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 8th argument of func. * ​T8 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 9th argument of func. * ​T9 ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the 10th argument of func. * ​T\_out ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Output type of func. **Args:** * ​func (`fn(T0, T1, T2, T3, T4, T5, T6, T7, T8, T9) -> T_out`): The function to be benchmarked (run in benchmark iterations). * ​x0 (`T0`): The 1st argument of func. * ​x1 (`T1`): The 2nd argument of func. * ​x2 (`T2`): The 3rd argument of func. * ​x3 (`T3`): The 4th argument of func. * ​x4 (`T4`): The 5th argument of func. * ​x5 (`T5`): The 6th argument of func. * ​x6 (`T6`): The 7th argument of func. * ​x7 (`T7`): The 8th argument of func. * ​x8 (`T8`): The 9th argument of func. * ​x9 (`T9`): The 10th argument of func. * ​bench\_id ([`BenchId`](/mojo/stdlib/benchmark/bencher/BenchId)): The benchmark Id object used for identification. * ​measures ([`List`](/mojo/stdlib/collections/list/List)): Optional arg used to represent a list of ThroughputMeasure's. **Raises:** If the operation fails.
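Putting it together, `QuickBench` lets you benchmark a plain function by passing its arguments positionally, without writing `Bencher` boilerplate yourself. A minimal sketch (assuming `QuickBench` is importable from `benchmark` and `BenchId` takes a name string):

```mojo
from benchmark import QuickBench
from benchmark.bencher import BenchId

fn add(x: Int, y: Int) -> Int:
    return x + y

fn main() raises:
    var qb = QuickBench()
    # Benchmarks `add(3, 4)` via the two-argument overload of `run`;
    # the parameters T0, T1, and T_out are inferred from `add`.
    qb.run(add, 3, 4, bench_id=BenchId("add"))
    qb.dump_report()
```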
--- ## quick_bench
## Structs * [​`QuickBench`](/mojo/stdlib/benchmark/quick_bench/QuickBench): Defines a struct to facilitate benchmarking and avoiding `Bencher` boilerplate.
--- ## bit_not
`bit_not[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Performs a bitwise NOT operation on an SIMD vector of integer values. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` is computed as a bitwise NOT of the integer value at position `i` of the input value.
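For example, on an unsigned 8-bit vector `bit_not` flips every bit of each lane. A minimal sketch (note that `bit_not` has no scalar `Int` overload, so the input is a `SIMD` value):

```mojo
from bit import bit_not

fn main():
    var v = SIMD[DType.uint8, 4](0, 1, 0b00001111, 255)
    # Each lane is inverted: 0 -> 255, 1 -> 254, 15 -> 240, 255 -> 0.
    print(bit_not(v))
```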
--- ## bit_reverse
`bit_reverse(val: Int) -> Int` Reverses the bitpattern of an integer value. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The input value with its bitpattern reversed. `bit_reverse[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Element-wise reverses the bitpattern of a SIMD vector of integer values. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` has a reversed bitpattern of an integer value of the element at position `i` of the input value.
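For example, reversing the bit pattern of a `UInt8` lane maps `0b00000001` to `0b10000000`. A minimal sketch:

```mojo
from bit import bit_reverse

fn main():
    var v = SIMD[DType.uint8, 2](0b00000001, 0b00000110)
    # 0b00000001 -> 0b10000000 (128), 0b00000110 -> 0b01100000 (96).
    print(bit_reverse(v))
```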
--- ## bit_width
`bit_width(val: Int) -> Int` Computes the minimum number of bits required to represent the integer. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of bits required to represent the integer. `bit_width[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Computes the minimum number of bits required to represent each element of a SIMD vector of integer values. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` equals the number of bits required to represent the integer at position `i` of the input.
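For non-negative inputs, the documented behavior ("minimum number of bits required to represent the integer") matches Python's built-in `int.bit_length()`, which can serve as a quick reference for expected values:

```python
# bit_width of a non-negative integer matches int.bit_length() in Python.
for val in (0, 1, 2, 255, 256):
    print(val, val.bit_length())
```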
--- ## byte_swap
`byte_swap(val: Int) -> Int` Byte-swaps an integer value with an even number of bytes. Swaps the bytes of an 8-byte integer value: if the input bytes are numbered 0, 1, 2, 3, 4, 5, 6, 7, the returned integer has its bytes in 7, 6, 5, 4, 3, 2, 1, 0 order. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The input value with its bytes swapped. `byte_swap[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Byte-swaps a SIMD vector of integer values with an even number of bytes. Byte-swaps each element of a vector of integer values with an even number of bytes (a positive multiple of 16 bits). For example, for Int16 elements the high and low byte of each input element are swapped. For Int32 elements the four bytes of each input element are swapped, so that if the input bytes are numbered 0, 1, 2, 3 then the result has its bytes in 3, 2, 1, 0 order. Int64 and other integer types extend this pattern to additional even-byte lengths. **Constraints:** The element type of the input vector must be an integral type. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` is the value of the element at position `i` of the input value with its bytes swapped.
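A byte swap is a reversal of the value's byte order. The following Python sketch reproduces the documented semantics via a little-endian/big-endian round trip (the helper name and `nbytes` parameter are illustrative, not part of this API):

```python
def byte_swap(val, nbytes=8):
    """Reverse the byte order of a fixed-size unsigned integer."""
    return int.from_bytes(val.to_bytes(nbytes, "little"), "big")

print(f"{byte_swap(0x0102030405060708):016x}")  # -> 0807060504030201
```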
--- ## count_leading_zeros
`count_leading_zeros(val: Int) -> Int` Counts the number of leading zeros of an integer. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of leading zeros of the input. `count_leading_zeros[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Counts the per-element number of leading zeros in a SIMD vector. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `DType` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` contains the number of leading zeros at position `i` of the input value.
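For a fixed-width non-negative integer, the leading-zero count is the width minus the position just past the highest set bit. A Python sketch of the same semantics for a 64-bit value (the helper name and `width` parameter are illustrative, not part of this API):

```python
def count_leading_zeros(val, width=64):
    """Leading zeros of a fixed-width non-negative integer; width if val == 0."""
    return width - val.bit_length()

print(count_leading_zeros(1))  # -> 63
print(count_leading_zeros(0))  # -> 64
```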
--- ## count_trailing_zeros
`count_trailing_zeros(val: Int) -> Int` Counts the number of trailing zeros for an integer. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of trailing zeros of the input. `count_trailing_zeros[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Counts the per-element number of trailing zeros in a SIMD vector. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` contains the number of trailing zeros at position `i` of the input value.
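The trailing-zero count can be computed by isolating the lowest set bit with `val & -val`. A Python sketch of the same semantics (the helper name and `width` parameter are illustrative, not part of this API):

```python
def count_trailing_zeros(val, width=64):
    """Trailing zeros of a fixed-width integer; width if val == 0."""
    if val == 0:
        return width
    return (val & -val).bit_length() - 1  # position of the lowest set bit

print(count_trailing_zeros(0b1000))  # -> 3
```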
--- ## bit
Provides functions for bit manipulation. You can import these APIs from the `bit` package. For example: ```mojo from bit import count_leading_zeros ``` ## Functions * [​`bit_not`](/mojo/stdlib/bit/bit/bit_not): Performs a bitwise NOT operation on a SIMD vector of integer values. * [​`bit_reverse`](/mojo/stdlib/bit/bit/bit_reverse): Reverses the bitpattern of an integer value. * [​`bit_width`](/mojo/stdlib/bit/bit/bit_width): Computes the minimum number of bits required to represent the integer. * [​`byte_swap`](/mojo/stdlib/bit/bit/byte_swap): Byte-swaps an integer value with an even number of bytes. * [​`count_leading_zeros`](/mojo/stdlib/bit/bit/count_leading_zeros): Counts the number of leading zeros of an integer. * [​`count_trailing_zeros`](/mojo/stdlib/bit/bit/count_trailing_zeros): Counts the number of trailing zeros for an integer. * [​`log2_ceil`](/mojo/stdlib/bit/bit/log2_ceil): Returns the ceiling of the base-2 logarithm of an integer value. * [​`log2_floor`](/mojo/stdlib/bit/bit/log2_floor): Returns the floor of the base-2 logarithm of an integer value. * [​`next_power_of_two`](/mojo/stdlib/bit/bit/next_power_of_two): Computes the smallest power of 2 that is greater than or equal to the input value. Any integral value less than or equal to 1 will be ceiled to 1. * [​`pop_count`](/mojo/stdlib/bit/bit/pop_count): Counts the number of bits set in an integer value. * [​`prev_power_of_two`](/mojo/stdlib/bit/bit/prev_power_of_two): Computes the largest power of 2 that is less than or equal to the input value. Any integral value less than or equal to 0 will be floored to 0. * [​`rotate_bits_left`](/mojo/stdlib/bit/bit/rotate_bits_left): Shifts the bits of an input to the left by `shift` bits (with wrap-around). * [​`rotate_bits_right`](/mojo/stdlib/bit/bit/rotate_bits_right): Shifts the bits of an input to the right by `shift` bits (with wrap-around).
--- ## log2_ceil
`log2_ceil(val: Int) -> Int` Returns the ceiling of the base-2 logarithm of an integer value. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The ceiling of the base-2 logarithm of the input value, which corresponds to the smallest power of 2 greater than or equal to the input. Returns 0 if val is 0. `log2_ceil(val: Scalar[dtype]) -> Scalar[dtype]` Returns the ceiling of the base-2 logarithm of an integer value. **Args:** * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The input value. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The smallest integer `n` such that `2^n` is greater than or equal to the input value. Returns 0 if `val` is 0.
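The documented result (smallest `n` such that `2^n >= val`, with 0 for inputs of 0 or 1) can be sketched in Python using `int.bit_length()` (the helper name is illustrative, not part of this API):

```python
def log2_ceil(val):
    """Smallest n with 2**n >= val; 0 if val <= 1."""
    return (val - 1).bit_length() if val > 1 else 0

for val in (1, 8, 9):
    print(val, log2_ceil(val))  # 1 -> 0, 8 -> 3, 9 -> 4
```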
--- ## log2_floor
`log2_floor(val: Int) -> Int` Returns the floor of the base-2 logarithm of an integer value. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The floor of the base-2 logarithm of the input value, which is equal to the position of the highest set bit. Returns -1 if val is 0 or negative. `log2_floor(val: UInt) -> UInt` Returns the floor of the base-2 logarithm of an integer value. **Args:** * ​val ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The input value. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The floor of the base-2 logarithm of the input value, which is equal to the position of the highest set bit. Returns UInt.MAX if val is 0. `log2_floor[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Returns the floor of the base-2 logarithm of an integer value. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The `dtype` of the input SIMD vector. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the input and output SIMD vector. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The floor of the base-2 logarithm of the input value, which is equal to the position of the highest set bit. Returns -1 if val is 0 or negative.
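The documented result is the position of the highest set bit, with -1 for zero or negative inputs. A Python sketch of the same semantics (the helper name is illustrative, not part of this API):

```python
def log2_floor(val):
    """Position of the highest set bit; -1 if val <= 0."""
    return val.bit_length() - 1 if val > 0 else -1

print(log2_floor(1), log2_floor(8), log2_floor(9), log2_floor(0))
# -> 0 3 3 -1
```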
--- ## next_power_of_two
`next_power_of_two(val: Int) -> Int` Computes the smallest power of 2 that is greater than or equal to the input value. Any integral value less than or equal to 1 will be ceiled to 1. Notes: This operation is called `bit_ceil()` in C++. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The smallest power of 2 that is greater than or equal to the input value. `next_power_of_two(val: UInt) -> UInt` Computes the smallest power of 2 that is greater than or equal to the input value. Any integral value less than or equal to 1 will be ceiled to 1. Notes: This operation is called `bit_ceil()` in C++. **Args:** * ​val ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The input value. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The smallest power of 2 that is greater than or equal to the input value. `next_power_of_two[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Computes the smallest power of 2 that is greater than or equal to the input value for each element of a SIMD vector. Any integral value less than or equal to 1 will be ceiled to 1. This operation is called `bit_ceil()` in C++. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` is the smallest power of 2 that is greater than or equal to the integer at position `i` of the input value.
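The "ceiled to 1" rule and the power-of-2 rounding can be sketched in Python with a bit-length computation (the helper name is illustrative, not part of this API):

```python
def next_power_of_two(val):
    """Smallest power of 2 >= val; any value <= 1 is ceiled to 1."""
    return 1 << (val - 1).bit_length() if val > 1 else 1

for val in (0, 1, 2, 5, 8):
    print(val, next_power_of_two(val))  # 0 -> 1, 1 -> 1, 2 -> 2, 5 -> 8, 8 -> 8
```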
--- ## pop_count
`pop_count(val: Int) -> Int` Counts the number of bits set in an integer value. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of bits set in the input value. `pop_count[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Counts the number of bits set in a SIMD vector of integer values. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` contains the number of bits set in the element at position `i` of the input value.
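Population count simply tallies the set bits of the value, as this Python sketch shows (the helper name is illustrative, not part of this API):

```python
def pop_count(val):
    """Number of set bits (population count) of a non-negative integer."""
    return bin(val).count("1")

print(pop_count(0b1011_0010))  # -> 4
```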
--- ## prev_power_of_two
`prev_power_of_two(val: Int) -> Int` Computes the largest power of 2 that is less than or equal to the input value. Any integral value less than or equal to 0 will be floored to 0. This operation is called `bit_floor()` in C++. **Args:** * ​val ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The largest power of 2 that is less than or equal to the input value. `prev_power_of_two[dtype: DType, width: Int, //](val: SIMD[dtype, width]) -> SIMD[dtype, width]` Computes the largest power of 2 that is less than or equal to the input value for each element of a SIMD vector. Any integral value less than or equal to 0 will be floored to 0. This operation is called `bit_floor()` in C++. **Constraints:** The element type of the input vector must be integral. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): `dtype` used for the computation. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width used for the computation. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The input value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` is the largest power of 2 that is less than or equal to the integer at position `i` of the input value.
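The "floored to 0" rule and the downward power-of-2 rounding can be sketched in Python (the helper name is illustrative, not part of this API):

```python
def prev_power_of_two(val):
    """Largest power of 2 <= val; any value <= 0 is floored to 0."""
    return 1 << (val.bit_length() - 1) if val > 0 else 0

for val in (0, 1, 5, 8, 9):
    print(val, prev_power_of_two(val))  # 0 -> 0, 1 -> 1, 5 -> 4, 8 -> 8, 9 -> 8
```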
--- ## rotate_bits_left
`rotate_bits_left[shift: Int](x: Int) -> Int` Shifts the bits of an input to the left by `shift` bits (with wrap-around). **Constraints:** `-size <= shift < size` **Parameters:** * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of bit positions by which to rotate the bits of the integer to the left (with wrap-around). **Args:** * ​x ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The input rotated to the left by `shift` elements (with wrap-around). `rotate_bits_left[dtype: DType, width: Int, //, shift: Int](x: SIMD[dtype, width]) -> SIMD[dtype, width]` Shifts bits to the left by `shift` positions (with wrap-around) for each element of a SIMD vector. **Constraints:** `0 <= shift < size` **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The `dtype` of the input and output SIMD vector. Must be integral and unsigned. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the SIMD vector. * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of positions to rotate left. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): SIMD vector input. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector with each element rotated left by `shift` bits.
--- ## rotate_bits_right
`rotate_bits_right[shift: Int](x: Int) -> Int` Shifts the bits of an input to the right by `shift` bits (with wrap-around). **Constraints:** `-size <= shift < size` **Parameters:** * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of bit positions by which to rotate the bits of the integer to the right (with wrap-around). **Args:** * ​x ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The input rotated to the right by `shift` elements (with wrap-around). `rotate_bits_right[dtype: DType, width: Int, //, shift: Int](x: SIMD[dtype, width]) -> SIMD[dtype, width]` Shifts bits to the right by `shift` positions (with wrap-around) for each element of a SIMD vector. **Constraints:** `0 <= shift < size` **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The `dtype` of the input and output SIMD vector. Must be integral and unsigned. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the SIMD vector. * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of positions to rotate right. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): SIMD vector input. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector with each element rotated right by `shift` bits.
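Both rotations move bits off one end of the fixed-width value and back in at the other. A Python sketch of wrap-around rotation for 8-bit values (the helper names and `width` parameter are illustrative, not part of this API):

```python
def rotate_left(x, shift, width=8):
    """Rotate a width-bit value left by shift positions (with wrap-around)."""
    mask = (1 << width) - 1
    shift %= width
    return ((x << shift) | (x >> (width - shift))) & mask

def rotate_right(x, shift, width=8):
    """Rotate right, expressed as a complementary left rotation."""
    return rotate_left(x, width - (shift % width), width)

print(bin(rotate_left(0b1000_0001, 1)))   # -> 0b11 (0b0000_0011)
print(bin(rotate_right(0b1000_0001, 1)))  # -> 0b11000000
```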
--- ## bit (Bit)
Implements the bit package. ## Modules * [​`bit`](/mojo/stdlib/bit/bit/): Provides functions for bit manipulation.
--- ## AnyType
A trait for types that require lifetime management through destructors. The `AnyType` trait is fundamental to Mojo's memory management system. It indicates that a type has a destructor that needs to be called when instances go out of scope. This is essential for types that own resources like memory, file handles, or other system resources that need proper cleanup. Key aspects: * Any type with a destructor must implement this trait * The destructor (`__del__`) is called automatically when an instance's lifetime ends * Composition of types with destructors automatically gets a destructor * All Mojo structs and traits inherit from `AnyType` by default unless they specify `@explicit_destroy` Example: ```mojo struct ResourceOwner(AnyType): var ptr: UnsafePointer[Int] fn __init__(out self, size: Int): self.ptr = UnsafePointer[Int].alloc(size) fn __del__(deinit self): # Clean up owned resources self.ptr.free() ``` Best practices: * Implement this trait when your type owns resources that need cleanup * Ensure the destructor properly frees all owned resources * Consider using `@explicit_destroy` for types that should never have destructors * Use composition to automatically handle nested resource cleanup ## Implemented traits [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__del__` `__del__(deinit self: _Self, /)` Destroys the instance and cleans up any owned resources. This method is called automatically when an instance's lifetime ends. 
It receives an owned value and should perform all necessary cleanup operations like: * Freeing allocated memory * Closing file handles * Releasing system resources * Cleaning up any other owned resources The instance is considered dead after this method completes, regardless of whether any explicit cleanup was performed.
--- ## UnknownDestructibility
The most basic trait that all Mojo types extend by default. This trait indicates that a type has no destructor and therefore no lifetime management. It is the default for all types unless they explicitly implement `AnyType` or `ImplicitlyDestructible`. Types with this trait: * Have no `__del__` method * Do not perform any cleanup when they go out of scope * Are suitable for simple value types that don't own resources For types that need cleanup when they are destroyed, use `ImplicitlyDestructible` or `AnyType` instead.
--- ## anytype (Anytype)
Defines the core traits for object lifetime management in Mojo. This module provides the foundational traits that define how objects are created, managed and destroyed in Mojo: * `UnknownDestructibility`: The most basic trait that all types extend by default. Types with this trait have no destructor and no lifetime management. * `AnyType`: The base trait for types that require lifetime management through destructors. Any type that needs cleanup when it goes out of scope should implement this trait. * `ImplicitlyDestructible`: An alias for `AnyType` to help with the transition to linear types. Use this when you want to be explicit about a type having a destructor. These traits are built into Mojo and do not need to be imported. ## `comptime` values ### `__SomeImpl` `comptime __SomeImpl[Trait: AnyTrivialRegType, T: Trait] = T` #### Parameters * ​Trait ([`AnyTrivialRegType`](/stdlib/builtin/type_aliases/#anytrivialregtype)): * ​T (`Trait`): ### `ImplicitlyDestructible` `comptime ImplicitlyDestructible = AnyType` Temporary alias for types that can be implicitly destroyed. ### `Some` `comptime Some[Trait: AnyTrivialRegType] = alias[T: Trait] T` An alias allowing users to tersely express that a function argument is an instance of a type that implements a trait or trait composition. For example, instead of writing ```mojo fn foo[T: Intable, //](x: T) -> Int: return x.__int__() ``` one can write: ```mojo fn foo(x: Some[Intable]) -> Int: return x.__int__() ``` #### Parameters * ​Trait ([`AnyTrivialRegType`](/stdlib/builtin/type_aliases/#anytrivialregtype)): The trait or trait composition that the argument type must implement. ## Traits * [​`AnyType`](/mojo/stdlib/builtin/anytype/AnyType): A trait for types that require lifetime management through destructors. * [​`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility): The most basic trait that all Mojo types extend by default.
--- ## Bool
`@register_passable(trivial)` `struct Bool` The primitive Bool scalar value used in Mojo. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`ConvertibleFromPython`](/mojo/stdlib/python/conversions/ConvertibleFromPython), [`ConvertibleToPython`](/mojo/stdlib/python/conversions/ConvertibleToPython), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Floatable`](/mojo/stdlib/builtin/floatable/Floatable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Indexer`](/mojo/stdlib/builtin/int/Indexer), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `MAX` `comptime MAX = Bool.__init__(True)` The maximum value of a Bool. ### `MIN` `comptime MIN = Bool.__init__(False)` The minimum value of a Bool. ## Methods ### `__init__` `__init__() -> Self` Construct a default, `False` Bool. `__init__[T: Boolable, //](value: T) -> Self` Set the bool representation of the object. **Parameters:** * ​T ([`Boolable`](/mojo/stdlib/builtin/bool/Boolable)): The type of the object. **Args:** * ​value (`T`): The object to get the bool representation of. `__init__(value: None) -> Self` Set the bool representation of the `None` type to `False`. 
**Args:** * ​value (`None`): The object to get the bool representation of. `@implicit` `__init__(value: Scalar[DType.bool]) -> Self` Convert a scalar SIMD value to a Bool. **Args:** * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The scalar value. ### `__bool__` `__bool__(self) -> Self` Convert to Bool. **Returns:** `Self`: This value. ### `__invert__` `__invert__(self) -> Self` Inverts the Bool value. **Returns:** `Self`: True if the object is false and False otherwise. ### `__lt__` `__lt__(self, rhs: Self) -> Self` Compare this Bool to RHS using less-than comparison. **Args:** * ​rhs (`Self`): The rhs of the operation. **Returns:** `Self`: True if self is False and rhs is True. ### `__le__` `__le__(self, rhs: Self) -> Self` Compare this Bool to RHS using less-than-or-equal comparison. **Args:** * ​rhs (`Self`): The rhs of the operation. **Returns:** `Self`: True if self is False and rhs is True or False. ### `__eq__` `__eq__(self, rhs: Self) -> Self` Compare this Bool to RHS. Performs an equality comparison between the Bool value and the argument. This method gets invoked when a user uses the `==` infix operator. **Args:** * ​rhs (`Self`): The rhs value of the equality statement. **Returns:** `Self`: True if the two values match and False otherwise. ### `__ne__` `__ne__(self, rhs: Self) -> Self` Compare this Bool to RHS. Performs a non-equality comparison between the Bool value and the argument. This method gets invoked when a user uses the `!=` infix operator. **Args:** * ​rhs (`Self`): The rhs value of the non-equality statement. **Returns:** `Self`: False if the two values do match and True otherwise. ### `__gt__` `__gt__(self, rhs: Self) -> Self` Compare this Bool to RHS using greater-than comparison. **Args:** * ​rhs (`Self`): The rhs of the operation. **Returns:** `Self`: True if self is True and rhs is False. ### `__ge__` `__ge__(self, rhs: Self) -> Self` Compare this Bool to RHS using greater-than-or-equal comparison. 
**Args:** * ​rhs (`Self`): The rhs of the operation. **Returns:** `Self`: True if self is True and rhs is True or False. ### `__and__` `__and__(self, rhs: Self) -> Self` Returns `self & rhs`. Bitwise and's the Bool value with the argument. This method gets invoked when a user uses the `and` infix operator. **Args:** * ​rhs (`Self`): The right hand side of the `and` statement. **Returns:** `Self`: `self & rhs`. ### `__or__` `__or__(self, rhs: Self) -> Self` Returns `self | rhs`. Bitwise or's the Bool value with the argument. This method gets invoked when a user uses the `or` infix operator. **Args:** * ​rhs (`Self`): The right hand side of the `or` statement. **Returns:** `Self`: `self | rhs`. ### `__xor__` `__xor__(self, rhs: Self) -> Self` Returns `self ^ rhs`. Bitwise Xor's the Bool value with the argument. This method gets invoked when a user uses the `^` infix operator. **Args:** * ​rhs (`Self`): The right hand side of the `xor` statement. **Returns:** `Self`: `self ^ rhs`. ### `__rand__` `__rand__(self, lhs: Self) -> Self` Returns `lhs & self`. **Args:** * ​lhs (`Self`): The left hand side of the `and` statement. **Returns:** `Self`: `lhs & self`. ### `__ror__` `__ror__(self, lhs: Self) -> Self` Returns `lhs | self`. **Args:** * ​lhs (`Self`): The left hand side of the `or` statement. **Returns:** `Self`: `lhs | self`. ### `__rxor__` `__rxor__(self, lhs: Self) -> Self` Returns `lhs ^ self`. **Args:** * ​lhs (`Self`): The left hand side of the `xor` statement. **Returns:** `Self`: `lhs ^ self`. ### `__iand__` `__iand__(mut self, rhs: Self)` Computes `self & rhs` and store the result in `self`. **Args:** * ​rhs (`Self`): The right hand side of the `and` statement. ### `__ixor__` `__ixor__(mut self, rhs: Self)` Computes `self ^ rhs` and stores the result in `self`. **Args:** * ​rhs (`Self`): The right hand side of the `xor` statement. ### `__ior__` `__ior__(mut self, rhs: Self)` Computes `self | rhs` and store the result in `self`. 
**Args:** * ​rhs (`Self`): The right hand side of the `or` statement. ### `__str__` `__str__(self) -> String` Get the bool as a string. Returns `"True"` or `"False"`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation. ### `write_to` `write_to(self, mut writer: T)` Formats this boolean to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `__repr__` `__repr__(self) -> String` Get the bool as a string. Returns `"True"` or `"False"`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation. ### `__int__` `__int__(self) -> Int` Convert this Bool to an integer. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): 1 if the Bool is True, 0 otherwise. ### `__as_int__` `__as_int__(self) -> Int` Implicitly convert to an integral representation of the value, wherever an `Int` is expected. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integral representation of the value. ### `__mlir_index__` `__mlir_index__(self) -> __mlir_type.index` Convert to index. **Returns:** `__mlir_type.index`: 1 if the Bool is True, 0 otherwise. ### `__float__` `__float__(self) -> Float64` Convert this Bool to a float. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): 1.0 if the Bool is True, 0.0 otherwise. ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance. ### `to_python_object` `to_python_object(var self) -> PythonObject` Convert this value to a PythonObject. **Returns:** `PythonObject`: A PythonObject representing the value. **Raises:** If the Python runtime is not initialized or conversion fails.
--- ## Boolable
The `Boolable` trait describes a type that can be explicitly converted to a `Bool` or evaluated as a boolean expression in `if` or `while` conditions. This trait requires the type to implement the `__bool__()` method. For example: ```mojo struct Foo(Boolable): var val: Bool fn __bool__(self) -> Bool: return self.val ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__bool__` `__bool__(self: _Self) -> Bool` Get the boolean representation of the value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): The boolean representation of the value.
--- ## all
`all[IterableType: Iterable](iterable: IterableType) -> Bool` Checks if **all** elements in the iterable are truthy. **Parameters:** * ​IterableType ([`Iterable`](/mojo/stdlib/iter/Iterable)): The type of the iterable containing `Boolable` items. **Args:** * ​iterable (`IterableType`): The iterable to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if **all** elements in the iterable are truthy, `False` otherwise. `all(value: SIMD[dtype, size]) -> Bool` Checks if **all** elements in the simd vector are truthy. **Args:** * ​value ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The simd vector to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if **all** elements in the simd vector are truthy, `False` otherwise.
--- ## any
`any[IterableType: Iterable](iterable: IterableType) -> Bool` Checks if **any** element in the iterable is truthy. **Parameters:** * ​IterableType ([`Iterable`](/mojo/stdlib/iter/Iterable)): The type of the iterable containing `Boolable` items. **Args:** * ​iterable (`IterableType`): The iterable to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if **any** element in the iterable is truthy, `False` otherwise. `any(value: SIMD[dtype, size]) -> Bool` Checks if **any** element in the simd vector is truthy. **Args:** * ​value ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The simd vector to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if **any** element in the simd vector is truthy, `False` otherwise.
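The truthiness semantics of `all` and `any` over an iterable or over SIMD lanes match the behavior of Python's built-ins over a list of lane values, as this quick sketch shows:

```python
lanes = [1, 0, 3, 7]  # stand-in for per-lane SIMD values
print(all(lanes), any(lanes))  # -> False True
```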
--- ## bool (Bool)
Implements the Bool class. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`Bool`](/mojo/stdlib/builtin/bool/Bool): The primitive Bool scalar value used in Mojo. ## Traits * [​`Boolable`](/mojo/stdlib/builtin/bool/Boolable): The `Boolable` trait describes a type that can be explicitly converted to a `Bool` or evaluated as a boolean expression in `if` or `while` conditions. ## Functions * [​`all`](/mojo/stdlib/builtin/bool/all): Checks if **all** elements in the iterable are truthy. * [​`any`](/mojo/stdlib/builtin/bool/any): Checks if **any** element in the iterable is truthy.
--- ## breakpoint
`breakpoint()` Cause an execution trap with the intention of requesting the attention of a debugger.
--- ## breakpoint (Breakpoint)
This module includes the builtin breakpoint function. ## Functions * [​`breakpoint`](/mojo/stdlib/builtin/breakpoint/breakpoint): Cause an execution trap with the intention of requesting the attention of a debugger.
--- ## ContiguousSlice
`struct ContiguousSlice` Represents a slice expression without a stride. This type is used to support different behavior for strided vs unstrided slicing. ## Fields * ​start (`Optional[Int]`): The starting index of the slice. * ​end (`Optional[Int]`): The end index of the slice. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self, start: Optional[Int], end: Optional[Int], stride: NoneType)` Construct slice given the start and end values. **Args:** * ​start ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The start value. * ​end ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The end value. * ​stride ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): Always none. Disambiguates from slices with a stride. ### `indices` `indices(self, length: Int) -> Tuple[Int, Int]` Returns a tuple of 2 integers representing the start and end of the slice if applied to a container of the given length. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The length of the target container. **Returns:** `Tuple`: A tuple containing two integers for start and end.
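The normalization that `indices` performs (clamping the start and end to the container's length) parallels Python's `slice.indices`, which can be used to check expected values. A Python sketch, assuming the same clamping rules:

```python
s = slice(2, 100)                  # like a stride-less slice with start=2, end=100
start, stop, _step = s.indices(8)  # clamp to a container of length 8
print(start, stop)  # -> 2 8

s = slice(-3, None)                # a negative start counts from the end
start, stop, _step = s.indices(8)
print(start, stop)  # -> 5 8
```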
--- ## Slice (Builtin_slice)
`struct Slice` Represents a slice expression. Objects of this type are generated when slice syntax is used within square brackets, e.g.: ```mojo var lst: List[Int] = [0,1,2,3,4,5,6,7] # Both are equivalent and result in the list [6, 7]. var l1 = List(lst[6:]) var l2 = lst.__getitem__(Slice(6, len(lst))) ``` ## Fields * ​start (`Optional[Int]`): The starting index of the slice. * ​end (`Optional[Int]`): The end index of the slice. * ​step (`Optional[Int]`): The step increment value of the slice. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self, start: Int, end: Int)` Construct slice given the start and end values. **Args:** * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start value. * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end value. `__init__(out self, start: Optional[Int], end: Optional[Int], step: Optional[Int])` Construct slice given the start, end and step values. **Args:** * ​start ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The start value. * ​end ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The end value. * ​step ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The step value. ### `__eq__` `__eq__(self, other: Self) -> Bool` Compare this slice to the other.
**Args:** * ​other (`Self`): The slice to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if start, end, and step values of this slice match the corresponding values of the other slice and False otherwise. ### `__str__` `__str__(self) -> String` Gets the string representation of the slice. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the slice. ### `__repr__` `__repr__(self) -> String` Gets the string representation of the slice. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the slice. ### `write_to` `write_to(self, mut writer: T)` Write Slice string representation to a `Writer`. **Args:** * ​writer (`T`): The object to write to. ### `indices` `indices(self, length: Int) -> Tuple[Int, Int, Int]` Returns a tuple of 3 integers representing the start, end, and step of the slice if applied to a container of the given length. Uses the target container length to normalize negative, out of bounds, or None indices. Negative indices are wrapped using the length of the container. ```mojo s = slice(0, -1, 1) i = s.indices(5) # returns (0, 4, 1) ``` None indices are defaulted to the start or the end of the container based on whether `step` is positive or negative. ```mojo s = slice(None, None, 1) i = s.indices(5) # returns (0, 5, 1) ``` Out of bounds indices are clamped using the size of the container. ```mojo s = slice(20) i = s.indices(5) # returns (0, 5, 1) ``` **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The length of the target container. **Returns:** `Tuple`: A tuple containing three integers for start, end, and step.
--- ## StridedSlice
`struct StridedSlice` Represents a slice expression that has a stride. This type is used to support different behavior for strided vs unstrided slicing. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `@implicit` `__init__(out self, other: Slice)` Implicitly convert from a general slice. **Args:** * ​other ([`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice)): The other slice. `__init__(out self, start: Optional[Int], end: Optional[Int], stride: Int)` Construct slice given start, end, and stride values. **Args:** * ​start ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The start value. * ​end ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The end value. * ​stride ([`Int`](/mojo/stdlib/builtin/int/Int)): The step value. ### `indices` `indices(self, length: Int) -> Tuple[Int, Int, Int]` Returns a tuple of 3 integers representing start, end, and step of the slice if applied to a container of given length. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The length of the target container. **Returns:** `Tuple`: A tuple containing three integers for start, end, and step.
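As with `ContiguousSlice`, strided slices usually come from slice syntax, but `indices()` can be exercised directly. A minimal sketch following the constructor signature above (the argument values are illustrative):

```mojo
def main():
    # Equivalent to the slice expression `[::2]`: every other element.
    var s = StridedSlice(None, None, 2)
    # Normalize against a container of length 8: yields (start, end, step).
    var t = s.indices(8)
    print(t[0], t[1], t[2])
```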
--- ## builtin_slice
Implements slice. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`ContiguousSlice`](/mojo/stdlib/builtin/builtin_slice/ContiguousSlice): Represents a slice expression without a stride. * [​`Slice`](/mojo/stdlib/builtin/builtin_slice/Slice): Represents a slice expression. * [​`StridedSlice`](/mojo/stdlib/builtin/builtin_slice/StridedSlice): Represents a slice expression that has a stride. ## Functions * [​`slice`](/mojo/stdlib/builtin/builtin_slice/slice-function): Construct slice given the end value.
--- ## slice (3)
`slice(end: Int) -> Slice` Construct slice given the end value. **Args:** * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end value. **Returns:** `Slice`: The constructed slice. `slice(start: Int, end: Int) -> Slice` Construct slice given the start and end values. **Args:** * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start value. * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end value. **Returns:** `Slice`: The constructed slice. `slice(start: Optional[Int], end: Optional[Int], step: Optional[Int]) -> Slice` Construct a Slice given the start, end and step values. **Args:** * ​start ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The start value. * ​end ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The end value. * ​step ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The step value. **Returns:** `Slice`: The constructed slice.
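The three overloads above can be sketched side by side; the explicit `Optional[Int]` wrapping in the last call is an assumption about how `None` mixes with concrete values:

```mojo
def main():
    var s1 = slice(5)        # end only
    var s2 = slice(1, 5)     # start and end
    var s3 = slice(Optional[Int](1), Optional[Int](None), Optional[Int](2))
    # Normalize s2 against a container of length 10.
    var t = s2.indices(10)
    print(t[0], t[1], t[2])
```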
--- ## Comparable
A type which can be compared for order with other instances of itself. Implementers of this trait must define the `__lt__` and `__eq__` methods. The default implementations of the comparison methods can be inefficient for types where comparison is expensive. For such types, it is recommended to override all the default implementations. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__lt__` `__lt__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is less than `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is less than `rhs`. ### `__eq__` `__eq__(self: _Self, other: _Self) -> Bool` Define whether two instances of the object are equal to each other. **Args:** * ​other (`_Self`): Another instance of the same type. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the instances are equal according to the type's definition of equality, False otherwise. ## Provided methods ### `__le__` `__le__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is less than or equal to `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is less than or equal to `rhs`.
### `__ne__` `__ne__(self: _Self, other: _Self) -> Bool` Define whether two instances of the object are not equal to each other. **Args:** * ​other (`_Self`): Another instance of the same type. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the instances are not equal according to the type's definition of equality, False otherwise. ### `__gt__` `__gt__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is greater than `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is greater than `rhs`. ### `__ge__` `__ge__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is greater than or equal to `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is greater than or equal to `rhs`.
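A conforming type only needs to define the two required methods; the remaining operators come from the provided defaults. A minimal sketch (the `Version` struct and its use of `@fieldwise_init` are illustrative):

```mojo
@fieldwise_init
struct Version(Comparable, Copyable, Movable):
    var major: Int
    var minor: Int

    fn __lt__(self, rhs: Self) -> Bool:
        # Lexicographic order: compare major first, then minor.
        if self.major != rhs.major:
            return self.major < rhs.major
        return self.minor < rhs.minor

    fn __eq__(self, rhs: Self) -> Bool:
        return self.major == rhs.major and self.minor == rhs.minor

def main():
    print(Version(1, 2) < Version(1, 10))   # uses the required __lt__
    print(Version(1, 2) >= Version(1, 10))  # provided __ge__, built on __lt__
```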
--- ## Equatable
A type which can be compared for equality with other instances of itself. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__eq__` `__eq__(self: _Self, other: _Self) -> Bool` Define whether two instances of the object are equal to each other. **Args:** * ​other (`_Self`): Another instance of the same type. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the instances are equal according to the type's definition of equality, False otherwise. ## Provided methods ### `__ne__` `__ne__(self: _Self, other: _Self) -> Bool` Define whether two instances of the object are not equal to each other. **Args:** * ​other (`_Self`): Another instance of the same type. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the instances are not equal according to the type's definition of equality, False otherwise.
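Implementing only `__eq__` is enough to conform; `__ne__` falls back to the provided default. A short illustrative sketch:

```mojo
@fieldwise_init
struct Point(Equatable, Copyable, Movable):
    var x: Int
    var y: Int

    fn __eq__(self, other: Self) -> Bool:
        return self.x == other.x and self.y == other.y

def main():
    print(Point(1, 2) == Point(1, 2))  # via __eq__
    print(Point(1, 2) != Point(3, 4))  # via the provided __ne__
```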
--- ## comparable (Comparable)
## `comptime` values ### `EqualityComparable` `comptime EqualityComparable = Equatable` Deprecated alias for `Equatable`. **Deprecated:** 'EqualityComparable' is deprecated, use 'Equatable' instead ## Traits * [​`Comparable`](/mojo/stdlib/builtin/comparable/Comparable): A type which can be compared for order with other instances of itself. * [​`Equatable`](/mojo/stdlib/builtin/comparable/Equatable): A type which can be compared for equality with other instances of itself.
--- ## constrained
`constrained[cond: Bool, msg: StringSlice[StaticConstantOrigin], *extra: StringSlice[StaticConstantOrigin]]()` Asserts that the condition must be true at compile time. The `constrained()` function introduces a compile-time constraint on the enclosing function. If the condition is true at compile time, the constraint has no effect. If the condition is false, compilation fails and the message is displayed. This is similar to `static_assert` in C++. It differs from [`debug_assert()`](/mojo/stdlib/builtin/debug_assert/debug_assert), which is a run-time assertion. Example: ```mojo fn half[dtype: DType](a: Scalar[dtype]) -> Scalar[dtype]: __comptime_assert dtype.is_numeric(), "dtype must be numeric." return a / 2 def main(): print(half(UInt8(5))) # prints 2 print(half(Scalar[DType.bool](True))) # constraint failed: # dtype must be numeric. ``` **Parameters:** * ​cond ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The bool value to assert. * ​msg ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The message to display on failure. * ​\*extra ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Additional messages to concatenate to msg. `constrained[cond: Bool]()` Asserts that the condition must be true at compile time. The `constrained()` function introduces a compile-time constraint on the enclosing function. If the condition is true at compile time, the constraint has no effect. If the condition is false, compilation fails and a generic message is displayed. This is similar to `static_assert` in C++. It differs from [`debug_assert()`](/mojo/stdlib/builtin/debug_assert/debug_assert), which is a run-time assertion. For an example, see the [first overload](/mojo/stdlib/builtin/constrained/constrained). **Parameters:** * ​cond ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The bool value to assert.
--- ## constrained (Constrained)
Implements compile-time constraints. These are Mojo built-ins, so you don't need to import them. ## Functions * [​`constrained`](/mojo/stdlib/builtin/constrained/constrained): Asserts that the condition must be true at compile time.
--- ## Coroutine
`@register_passable` `struct Coroutine[type: AnyType, origins: OriginSet]` Represents a coroutine. Coroutines can pause execution saving the state of the program (including values of local variables and the location of the next instruction to be executed). When the coroutine is resumed, execution continues from where it left off, with the saved state restored. ## Parameters * ​type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Type of value returned upon completion of the coroutine. * ​origins ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The origin of the coroutine's captures. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `@implicit` `__init__(handle: AnyCoroutine) -> Self` Construct a coroutine object from a handle. **Args:** * ​handle ([`AnyCoroutine`](/mojo/stdlib/builtin/coroutine/#anycoroutine)): The init handle. ### `__await__` `__await__(deinit self, out result: type)` Suspends the current coroutine until the coroutine is complete. **Returns:** `type`: The coroutine promise. ### `force_destroy` `force_destroy(deinit self)` Destroy the coroutine object.
--- ## RaisingCoroutine
`@register_passable` `struct RaisingCoroutine[type: AnyType, origins: OriginSet]` Represents a coroutine that can raise. Coroutines can pause execution saving the state of the program (including values of local variables and the location of the next instruction to be executed). When the coroutine is resumed, execution continues from where it left off, with the saved state restored. ## Parameters * ​type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Type of value returned upon completion of the coroutine. * ​origins ([`OriginSet`](/mojo/stdlib/builtin/type_aliases/#originset)): The origin set of the coroutine's captures. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `@implicit` `__init__(handle: AnyCoroutine) -> Self` Construct a coroutine object from a handle. **Args:** * ​handle ([`AnyCoroutine`](/mojo/stdlib/builtin/coroutine/#anycoroutine)): The init handle. ### `__await__` `__await__(var self, out result: type)` Suspends the current coroutine until the coroutine is complete. **Returns:** `type`: The result value from the completed coroutine. **Raises:** If the coroutine execution encounters an error. ### `force_destroy` `force_destroy(deinit self)` Destroy the coroutine object.
--- ## coroutine (Coroutine)
Implements classes and methods for coroutines. These are Mojo built-ins, so you don't need to import them. ## `comptime` values ### `AnyCoroutine` `comptime AnyCoroutine = __mlir_type.`!co.routine\`\` The MLIR type representing a coroutine handle. ## Structs * [​`Coroutine`](/mojo/stdlib/builtin/coroutine/Coroutine): Represents a coroutine. * [​`RaisingCoroutine`](/mojo/stdlib/builtin/coroutine/RaisingCoroutine): Represents a coroutine that can raise.
--- ## debug_assert
`debug_assert[cond: fn() capturing -> Bool, assert_mode: StringSlice[StaticConstantOrigin] = "none", *Ts: Writable = *?, *, cpu_only: Bool = False](*messages: *Ts)` Asserts that the condition is true at run time. If the condition is false, the assertion displays the given message and causes the program to exit. You can pass in multiple arguments to generate a formatted message. No string allocation occurs unless the assertion is triggered. ```mojo x = 0 debug_assert(x > 0, "expected x to be more than 0 but got: ", x) ``` Normal assertions are off by default—they only run when the program is compiled with all assertions enabled. You can set the `assert_mode` to `safe` to create an assertion that's on by default: ```mojo debug_assert[assert_mode="safe"]( x > 0, "expected x to be more than 0 but got: ", x ) ``` Use the `ASSERT` variable to turn assertions on or off when building or running a Mojo program: ```sh mojo -D ASSERT=all main.mojo ``` The `ASSERT` variable takes the following values: * all: Turn on all assertions. * safe: Turn on "safe" assertions only. This is the default. * none: Turn off all assertions, for performance at the cost of safety. * warn: Turn on all assertions, but print any errors instead of exiting. To ensure that you have no run-time penalty from your assertions even when they're disabled, make sure there are no side effects in your message and condition expressions. For example: ```mojo person = "name: john, age: 50" name = "john" debug_assert(String("name: ", name) in person, "unexpected name") ``` This will have a run-time penalty due to allocating a `String` in the condition expression, even when assertions are disabled. 
To avoid this, put the condition inside a closure so it runs only when the assertion is turned on: ```mojo fn check_name() capturing -> Bool: return String("name: ", name) in person debug_assert[check_name]("unexpected name") ``` If you need to allocate, and so don't want the assert to ever run on GPU, you can set it to CPU only: ```mojo debug_assert[check_name, cpu_only=True]("unexpected name") ``` For compile-time assertions, see [`constrained()`](/mojo/stdlib/builtin/constrained/constrained). **Parameters:** * ​cond (`fn() capturing -> Bool`): The function to invoke to check if the assertion holds. * ​assert\_mode ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Determines when the assert is turned on. * default ("none"): Turned on when compiled with `-D ASSERT=all`. * "safe": Turned on by default. * ​\*Ts ([`Writable`](/mojo/stdlib/io/write/Writable)): The element types for the message arguments. * ​cpu\_only ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true, only run the assert on CPU. **Args:** * ​\*messages (`*Ts`): A set of [`Writable`](/mojo/stdlib/io/write/Writable/) arguments to convert to a `String` message. `debug_assert[assert_mode: StringSlice[StaticConstantOrigin] = "none", *Ts: Writable = *?, *, cpu_only: Bool = False, _use_compiler_assume: Bool = False](cond: Bool, *messages: *Ts)` Asserts that the condition is true at run time. If the condition is false, the assertion displays the given message and causes the program to exit. You can pass in multiple arguments to generate a formatted message. No string allocation occurs unless the assertion is triggered. ```mojo x = 0 debug_assert(x > 0, "expected x to be more than 0 but got: ", x) ``` Normal assertions are off by default—they only run when the program is compiled with all assertions enabled. 
You can set the `assert_mode` to `safe` to create an assertion that's on by default: ```mojo debug_assert[assert_mode="safe"]( x > 0, "expected x to be more than 0 but got: ", x ) ``` Use the `ASSERT` variable to turn assertions on or off when building or running a Mojo program: ```sh mojo -D ASSERT=all main.mojo ``` The `ASSERT` variable takes the following values: * all: Turn on all assertions. * safe: Turn on "safe" assertions only. This is the default. * none: Turn off all assertions, for performance at the cost of safety. * warn: Turn on all assertions, but print any errors instead of exiting. To ensure that you have no run-time penalty from your assertions even when they're disabled, make sure there are no side effects in your message and condition expressions. For example: ```mojo person = "name: john, age: 50" name = "john" debug_assert(String("name: ", name) in person, "unexpected name") ``` This will have a run-time penalty due to allocating a `String` in the condition expression, even when assertions are disabled. To avoid this, put the condition inside a closure so it runs only when the assertion is turned on: ```mojo fn check_name() capturing -> Bool: return String("name: ", name) in person debug_assert[check_name]("unexpected name") ``` If you need to allocate, and so don't want the assert to ever run on GPU, you can set it to CPU only: ```mojo debug_assert[check_name, cpu_only=True]("unexpected name") ``` For compile-time assertions, see [`constrained()`](/mojo/stdlib/builtin/constrained/constrained). **Parameters:** * ​assert\_mode ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Determines when the assert is turned on. * default ("none"): Turned on when compiled with `-D ASSERT=all`. * "safe": Turned on by default. * ​\*Ts ([`Writable`](/mojo/stdlib/io/write/Writable)): The element types for the message arguments. * ​cpu\_only ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true, only run the assert on CPU. 
* ​\_use\_compiler\_assume ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true, assume the condition is true for repeated checks, to help the compiler optimize (default False). **Args:** * ​cond ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The bool value to assert. * ​\*messages (`*Ts`): A set of [`Writable`](/mojo/stdlib/io/write/Writable/) arguments to convert to a `String` message. `debug_assert[assert_mode: StringSlice[StaticConstantOrigin] = "none", cpu_only: Bool = False, _use_compiler_assume: Bool = False](cond: Bool, message: StringLiteral[value])` Asserts that the condition is true at run time. If the condition is false, the assertion displays the given message and causes the program to exit. You can pass in multiple arguments to generate a formatted message. No string allocation occurs unless the assertion is triggered. ```mojo x = 0 debug_assert(x > 0, "expected x to be more than 0 but got: ", x) ``` Normal assertions are off by default—they only run when the program is compiled with all assertions enabled. You can set the `assert_mode` to `safe` to create an assertion that's on by default: ```mojo debug_assert[assert_mode="safe"]( x > 0, "expected x to be more than 0 but got: ", x ) ``` Use the `ASSERT` variable to turn assertions on or off when building or running a Mojo program: ```sh mojo -D ASSERT=all main.mojo ``` The `ASSERT` variable takes the following values: * all: Turn on all assertions. * safe: Turn on "safe" assertions only. This is the default. * none: Turn off all assertions, for performance at the cost of safety. * warn: Turn on all assertions, but print any errors instead of exiting. To ensure that you have no run-time penalty from your assertions even when they're disabled, make sure there are no side effects in your message and condition expressions. 
For example: ```mojo person = "name: john, age: 50" name = "john" debug_assert(String("name: ", name) in person, "unexpected name") ``` This will have a run-time penalty due to allocating a `String` in the condition expression, even when assertions are disabled. To avoid this, put the condition inside a closure so it runs only when the assertion is turned on: ```mojo fn check_name() capturing -> Bool: return String("name: ", name) in person debug_assert[check_name]("unexpected name") ``` If you need to allocate, and so don't want the assert to ever run on GPU, you can set it to CPU only: ```mojo debug_assert[check_name, cpu_only=True]("unexpected name") ``` For compile-time assertions, see [`constrained()`](/mojo/stdlib/builtin/constrained/constrained). **Parameters:** * ​assert\_mode ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Determines when the assert is turned on. * default ("none"): Turned on when compiled with `-D ASSERT=all`. * "safe": Turned on by default. * ​cpu\_only ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true, only run the assert on CPU. * ​\_use\_compiler\_assume ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If true, assume the condition is true for repeated checks, to help the compiler optimize (default False). **Args:** * ​cond ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The bool value to assert. * ​message ([`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral)): A static string message.
--- ## debug_assert (Debug_assert)
Implements run-time assertions. These are Mojo built-ins, so you don't need to import them. ## `comptime` values ### `ASSERT_MODE` `comptime ASSERT_MODE = env_get_string["ASSERT", "safe"]()` The compile-time assertion mode from the ASSERT environment variable. ## Functions * [​`debug_assert`](/mojo/stdlib/builtin/debug_assert/debug_assert): Asserts that the condition is true at run time.
--- ## DevicePassable
This trait marks types as passable to accelerator devices. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `device_type` `comptime device_type` Indicate the type being used on accelerator devices. ## Required methods ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). For example, Int would return "Int", DeviceBuffer\[DType.float32] would return "DeviceBuffer\[DType.float32]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. For example, because DeviceBuffer's device\_type is UnsafePointer, DeviceBuffer\[DType.float32]'s get\_device\_type\_name() should return something like "UnsafePointer\[Scalar\[DType.float32]]". This is used for error messages when passing types to the device. TODO: This method will be retired soon when better kernel call error messages arrive. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name.
--- ## device_passable
## Traits * [​`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable): This trait marks types as passable to accelerator devices.
--- ## DType
`@register_passable(trivial)` `struct DType` Represents a data type specification and provides methods for working with it. `DType` defines a set of compile-time constants that specify the precise numeric representation of data in order to prevent runtime errors by catching type mismatches at compile time. It directly maps to CPU/GPU instruction sets, allowing the compiler to generate optimal SIMD and vector operations. `DType` behaves like an enum rather than a typical object. You don't instantiate it, but instead use its compile-time constants (aliases) to declare data types for SIMD vectors, tensors, and other data structures. Key usage patterns: * **Type specification**: Use aliases like `DType.float32` to specify types for SIMD vectors, tensors, and other data structures * **Type parameters**: Pass `DType` values as compile-time parameters to parameterized types like `SIMD[dtype, size]` * **Type introspection**: Call methods like `.is_floating_point()` to query type properties at compile time * **Type conversion**: Use in casting operations to convert between different numeric representations **Note:** Not all data types are supported on all platforms. For example, `DType.bfloat16` is currently not supported on Apple Silicon.
Example: ```mojo var data = SIMD[DType.float16, 4](1.5, 2.5, 3.5, 4.5) var dtype = data.dtype print("Is float:", dtype.is_floating_point()) # True print("Is signed:", dtype.is_signed()) # True ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`Identifiable`](/mojo/stdlib/builtin/identifiable/Identifiable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `bfloat16` `comptime bfloat16 = DType.bfloat16` Represents a brain floating point value whose bitwidth is 16. ### `bool` `comptime bool = DType.bool` Represents a boolean data type. ### `float16` `comptime float16 = DType.float16` Represents an IEEE754-2008 `binary16` floating point value. ### `float32` `comptime float32 = DType.float32` Represents an IEEE754-2008 `binary32` floating point value. ### `float4_e2m1fn` `comptime float4_e2m1fn = DType.float4_e2m1fn` Represents a 4-bit `e2m1` floating point format. This type is encoded as `s.ee.m` and defined by the [Open Compute MX Format Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf): * (s)ign: 1 bit * (e)xponent: 2 bits * (m)antissa: 1 bits * exponent\_bias: 1 ### `float64` `comptime float64 = DType.float64` Represents an IEEE754-2008 `binary64` floating point value. 
### `float8_e3m4` `comptime float8_e3m4 = DType.float8_e3m4` Represents an 8-bit `e3m4` floating point format. This type is encoded as `s.eee.mmmm`: * (s)ign: 1 bit * (e)xponent: 3 bits * (m)antissa: 4 bits * exponent bias: 3 * nan: {0,1}.111.1111 * fn: finite (no inf or -inf encodings) * -0: 1.000.0000 ### `float8_e4m3fn` `comptime float8_e4m3fn = DType.float8_e4m3fn` Represents the 8-bit `E4M3` floating point format defined in the [OFP8 standard](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1). This type is named differently across libraries and vendors, for example: * Mojo, PyTorch, JAX, and LLVM refer to it as `e4m3fn`. * OCP, NVIDIA CUDA, and AMD ROCm refer to it as `e4m3`. In these contexts, they are all referring to the same finite type specified in the OFP8 standard above, encoded as `s.eeee.mmm`: * (s)ign: 1 bit * (e)xponent: 4 bits * (m)antissa: 3 bits * exponent bias: 7 * nan: {0,1}.1111.111 * fn: finite (no inf or -inf encodings) * -0: 1.0000.000 ### `float8_e4m3fnuz` `comptime float8_e4m3fnuz = DType.float8_e4m3fnuz` Represents an 8-bit `e4m3fnuz` floating point format. See the [format reference](https://arxiv.org/pdf/2206.02915), encoded as `s.eeee.mmm`: * (s)ign: 1 bit * (e)xponent: 4 bits * (m)antissa: 3 bits * exponent bias: 8 * nan: 1.0000.000 * fn: finite (no inf or -inf encodings) * uz: unsigned zero (no -0 encoding) ### `float8_e5m2` `comptime float8_e5m2 = DType.float8_e5m2` Represents the 8-bit `E5M2` floating point format. This type is defined in the [OFP8 standard](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1), encoded as `s.eeeee.mm`: * (s)ign: 1 bit * (e)xponent: 5 bits * (m)antissa: 2 bits * exponent bias: 15 * nan: {0,1}.11111.{01,10,11} * inf: {0,1}.11111.00 * -0: 1.00000.00 ### `float8_e5m2fnuz` `comptime float8_e5m2fnuz = DType.float8_e5m2fnuz` Represents an 8-bit `e5m2fnuz` floating point format. 
See the [format reference](https://arxiv.org/pdf/2206.02915), encoded as `s.eeeee.mm`: * (s)ign: 1 bit * (e)xponent: 5 bits * (m)antissa: 2 bits * exponent bias: 16 * nan: 1.00000.00 * fn: finite (no inf or -inf encodings) * uz: unsigned zero (no -0 encoding) ### `float8_e8m0fnu` `comptime float8_e8m0fnu = DType.float8_e8m0fnu` Represents the 8-bit `E8M0Fnu` floating point format. This type is defined in the [OFP8 standard](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1), encoded as `eeeeeeee`: * (e)xponent: 8 bits * (m)antissa: 0 bits * exponent bias: 127 * nan: 11111111 * fn: finite (no inf or -inf encodings) * u: no sign or zero value. ### `int` `comptime int = DType.index` Represents an integral type whose bitwidth is the maximum integral value on the system. ### `int128` `comptime int128 = DType.int128` Represents a signed integer type whose bitwidth is 128. ### `int16` `comptime int16 = DType.int16` Represents a signed integer type whose bitwidth is 16. ### `int256` `comptime int256 = DType.int256` Represents a signed integer type whose bitwidth is 256. ### `int32` `comptime int32 = DType.int32` Represents a signed integer type whose bitwidth is 32. ### `int64` `comptime int64 = DType.int64` Represents a signed integer type whose bitwidth is 64. ### `int8` `comptime int8 = DType.int8` Represents a signed integer type whose bitwidth is 8. ### `invalid` `comptime invalid = DType.invalid` Represents an invalid or unknown data type. ### `uint` `comptime uint = DType.uindex` Represents an unsigned integral type whose bitwidth is the maximum unsigned integral value on the system. ### `uint128` `comptime uint128 = DType.uint128` Represents an unsigned integer type whose bitwidth is 128. ### `uint16` `comptime uint16 = DType.uint16` Represents an unsigned integer type whose bitwidth is 16. ### `uint256` `comptime uint256 = DType.uint256` Represents an unsigned integer type whose bitwidth is 256. 
### `uint32` `comptime uint32 = DType.uint32` Represents an unsigned integer type whose bitwidth is 32. ### `uint64` `comptime uint64 = DType.uint64` Represents an unsigned integer type whose bitwidth is 64. ### `uint8` `comptime uint8 = DType.uint8` Represents an unsigned integer type whose bitwidth is 8. ## Methods ### `__init__` `__init__(*, mlir_value: __mlir_type.`!kgen.dtype`) -> Self` Construct a DType from an MLIR dtype. **Args:** * ​mlir\_value (`__mlir_type.`!kgen.dtype\`\`): The MLIR dtype. ### `__eq__` `__eq__(self, rhs: Self) -> Bool` Compares one DType to another for equality. **Args:** * ​rhs (`Self`): The DType to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the DTypes are the same and False otherwise. ### `__ne__` `__ne__(self, rhs: Self) -> Bool` Compares one DType to another for inequality. **Args:** * ​rhs (`Self`): The DType to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): False if the DTypes are the same and True otherwise. ### `__is__` `__is__(self, rhs: Self) -> Bool` Compares one DType to another for equality. **Args:** * ​rhs (`Self`): The DType to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the DTypes are the same and False otherwise. ### `__isnot__` `__isnot__(self, rhs: Self) -> Bool` Compares one DType to another for inequality. **Args:** * ​rhs (`Self`): The DType to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the DTypes are different and False otherwise. ### `__str__` `__str__(self) -> String` Gets the name of the DType. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The name of the dtype. ### `write_to` `write_to(self, mut writer: T)` Formats this dtype to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `__repr__` `__repr__(self) -> String` Gets the representation of the DType e.g. `"DType.float32"`. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The representation of the dtype. ### `get_value` `get_value(self) -> __mlir_type.`!kgen.dtype\`\` Gets the associated internal kgen.dtype value. **Returns:** `__mlir_type.`!kgen.dtype\`\`: The kgen.dtype value. ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with this `DType` value. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance. ### `is_unsigned` `is_unsigned(self) -> Bool` Returns True if the type parameter is unsigned and False otherwise. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Returns True if the input type parameter is unsigned. ### `is_signed` `is_signed(self) -> Bool` Returns True if the type parameter is signed and False otherwise. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Returns True if the input type parameter is signed. ### `is_integral` `is_integral(self) -> Bool` Returns True if the type parameter is an integer and False otherwise. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Returns True if the input type parameter is an integer. ### `is_floating_point` `is_floating_point(self) -> Bool` Returns True if the type parameter is a floating-point type and False otherwise. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Returns True if the input type parameter is a floating-point type. ### `is_float8` `is_float8(self) -> Bool` Returns True if the dtype is an 8-bit precision floating point type, e.g. float8\_e5m2, float8\_e5m2fnuz, float8\_e4m3fn and float8\_e4m3fnuz. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the dtype is an 8-bit precision float, False otherwise. ### `is_half_float` `is_half_float(self) -> Bool` Returns True if the dtype is a half-precision floating point type, e.g. either fp16 or bf16. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the dtype is a half-precision float, False otherwise. 
### `is_numeric` `is_numeric(self) -> Bool` Returns True if the type parameter is numeric (i.e. you can perform arithmetic operations on it). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Returns True if the input type parameter is either integral or floating-point. ### `mantissa_width` `static mantissa_width[dtype: Self]() -> Int` Returns the mantissa width of a floating point type. **Parameters:** * ​dtype (`Self`): The DType. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The mantissa width. ### `max_exponent` `static max_exponent[dtype: Self]() -> Int` Returns the max exponent of a floating point dtype without accounting for inf representations. This is not the maximum representable exponent, which is generally equal to the exponent\_bias. **Parameters:** * ​dtype (`Self`): The DType. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The max exponent. ### `exponent_width` `static exponent_width[dtype: Self]() -> Int` Returns the exponent width of a floating point type. **Parameters:** * ​dtype (`Self`): The DType. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The exponent width. ### `exponent_bias` `static exponent_bias[dtype: Self]() -> Int` Returns the exponent bias of a floating point type. **Parameters:** * ​dtype (`Self`): The DType. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The exponent bias. ### `__mlir_type` `__mlir_type(self) -> __mlir_type.`!kgen.deferred\`\` Returns the MLIR type of the current DType as an MLIR type. **Returns:** `__mlir_type.`!kgen.deferred\`\`: The MLIR type of the current DType. ### `get_dtype` `static get_dtype[T: AnyType, size: Int = 1]() -> Self` Get the `DType` if the given Type is a `SIMD[_, size]` of a `DType`. **Parameters:** * ​T ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): AnyType. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The SIMD size to compare against. **Returns:** `Self`: The `DType` if matched, otherwise `DType.invalid`. 
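The static floating point queries above can be read off directly; a minimal sketch, assuming a `main` entry point (the printed values follow the standard IEEE754 `binary32` layout: 23 mantissa bits, 8 exponent bits, bias 127):

```mojo
fn main():
    # Query the IEEE754 binary32 layout of float32.
    print(DType.mantissa_width[DType.float32]())  # 23
    print(DType.exponent_width[DType.float32]())  # 8
    print(DType.exponent_bias[DType.float32]())   # 127
```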
### `is_scalar` `static is_scalar[T: AnyType]() -> Bool` Whether the given Type is a Scalar of a DType. **Parameters:** * ​T ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): AnyType. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `T` is a `Scalar` of a `DType`, False otherwise.
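A short sketch of the type-query helpers `get_dtype` and `is_scalar`, assuming the standard `Float32` alias for `Scalar[DType.float32]`:

```mojo
fn main():
    # get_dtype recovers the element DType from a SIMD/Scalar type.
    print(DType.get_dtype[Float32]() == DType.float32)  # True
    print(DType.is_scalar[Float32]())                   # True
    # A type that is not a SIMD of a DType yields DType.invalid.
    print(DType.get_dtype[Int]() == DType.invalid)
```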
--- ## dtype (Dtype)
Implements the DType class. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`DType`](/mojo/stdlib/builtin/dtype/DType): Represents a data type specification and provides methods for working with it.
--- ## Error
`struct Error` This type represents an Error. ## Fields * ​data (`String`): The message of the error. * ​stack\_trace (`StackTrace`): The stack trace of the error. By default the stack trace is not collected for the Error unless the user sets the stack\_trace\_depth parameter to a value >= 0. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `@implicit` `__init__(out self, var value: String, *, depth: Int = -1)` Construct an Error object with a given String. **Args:** * ​value ([`String`](/mojo/stdlib/collections/string/string/String)): The error message. * ​depth ([`Int`](/mojo/stdlib/builtin/int/Int)): The depth of the stack trace to collect. `__init__(out self)` Default constructor. `@implicit` `__init__(out self, value: StringLiteral[value])` Construct an Error object with a given string literal. **Args:** * ​value ([`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral)): The error message. `@implicit` `__init__(out self, arg: T)` Construct an Error from a Writable argument. **Args:** * ​arg (`T`): A Writable argument. `__init__[*Ts: Writable](out self, *args: *Ts)` Construct an Error by concatenating a sequence of Writable arguments. 
**Parameters:** * ​\*Ts ([`Writable`](/mojo/stdlib/io/write/Writable)): The types of the arguments to format. Each type must satisfy `Writable`. **Args:** * ​\*args (`*Ts`): A sequence of Writable arguments. ### `__bool__` `__bool__(self) -> Bool` Returns True if the error is set and False otherwise. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the error object contains a value and False otherwise. ### `__str__` `__str__(self) -> String` Converts the Error to a string representation. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A String of the error message. ### `write_to` `write_to(self, mut writer: T)` Formats this error to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `__repr__` `__repr__(self) -> String` Converts the Error to a printable representation. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A printable representation of the error message. ### `get_stack_trace` `get_stack_trace(self) -> StackTrace` Returns the stack trace of the error. **Returns:** `StackTrace`: The stringable stack trace of the error.
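Putting the constructors above together; a minimal sketch using only the documented `Error` API (the function and message text are illustrative, not part of the API):

```mojo
fn parse_positive(x: Int) raises -> Int:
    if x < 0:
        # The variadic Writable constructor concatenates its
        # arguments into the error message.
        raise Error("negative input: ", x)
    return x

fn main():
    try:
        _ = parse_positive(-5)
    except e:
        print(e)  # prints the concatenated message
```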
--- ## StackTrace
`@register_passable` `struct StackTrace` Holds a stack trace of the location where the StackTrace was constructed. ## Fields * ​value (`ArcPointer[OwnedPointer[UInt8]]`): A reference counting pointer to a char array containing the stack trace. Note: This owned pointer *can be null*. We'd use Optional\[OwnedPointer] but we don't have good niche optimization and Optional\[T] requires T: Copyable. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Construct an empty stack trace. `__init__(*, depth: Int) -> Self` Construct a new stack trace. **Args:** * ​depth ([`Int`](/mojo/stdlib/builtin/int/Int)): The depth of the stack trace. When `depth` is zero, the entire stack trace is collected. When `depth` is negative, no stack trace is collected. ### `__str__` `__str__(self) -> String` Converts the StackTrace to a string representation. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A String of the stack trace.
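A minimal sketch of the `depth` semantics documented above (zero collects the whole stack, negative collects nothing):

```mojo
fn main():
    var full = StackTrace(depth=0)    # depth == 0: collect the entire stack
    var empty = StackTrace(depth=-1)  # negative depth: collect nothing
    print(String(full))
    _ = empty
```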
--- ## error (Error)
Implements the Error class. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`Error`](/mojo/stdlib/builtin/error/Error): This type represents an Error. * [​`StackTrace`](/mojo/stdlib/builtin/error/StackTrace): Holds a stack trace of a location when StackTrace is constructed.
--- ## FloatLiteral
`@register_passable(trivial)` `struct FloatLiteral[value: __mlir_type.`!pop.float\_literal`]` Mojo floating point literal type. ## Parameters * ​value (`__mlir_type.`!pop.float\_literal\`\`): The underlying infinite precision floating point value. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Floatable`](/mojo/stdlib/builtin/floatable/Floatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `infinity` `comptime infinity = inf` Positive infinity value. ### `nan` `comptime nan` Not a number (NaN) value. ### `negative_infinity` `comptime negative_infinity = -inf` Negative infinity value. ### `negative_zero` `comptime negative_zero = -0.0` Negative zero value. ## Methods ### `__init__` `__init__() -> Self` Create a FloatLiteral for any parameter value. `@implicit` `__init__(_value: IntLiteral[value]) -> FloatLiteral[#pop.int_to_float_literal<*"value`2x">]\` Convert an IntLiteral to a FloatLiteral value. **Args:** * ​\_value ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The IntLiteral value. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral) ### `__bool__` `__bool__(self) -> Bool` A FloatLiteral value is true if it is non-zero. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if non-zero. 
### `__neg__` `__neg__(self) -> FloatLiteral[#pop.float_literal_bin< mul value, #pop.float_literal<-1|1>>]` Return the negation of the FloatLiteral value. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The negated FloatLiteral value. ### `__lt__` `__lt__(self, rhs: FloatLiteral[value]) -> Bool` Less than comparison. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this value is less than `rhs`. ### `__le__` `__le__(self, rhs: FloatLiteral[value]) -> Bool` Less than or equal to comparison. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this value is less than or equal to `rhs`. ### `__eq__` `__eq__(self, rhs: FloatLiteral[value]) -> Bool` Compare for equality. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if they are equal. ### `__ne__` `__ne__(self, rhs: FloatLiteral[value]) -> Bool` Compare for inequality. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if they are not equal. ### `__gt__` `__gt__(self, rhs: FloatLiteral[value]) -> Bool` Greater than comparison. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this value is greater than `rhs`. ### `__ge__` `__ge__(self, rhs: FloatLiteral[value]) -> Bool` Greater than or equal to comparison. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to compare. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this value is greater than or equal to `rhs`. ### `__add__` `__add__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< add value, *"value`2x">]\` Add two FloatLiterals. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to add. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The sum of the two values. ### `__sub__` `__sub__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< sub value, *"value`2x">]\` Subtract two FloatLiterals. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to subtract. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The difference of the two values. ### `__mul__` `__mul__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< mul value, *"value`2x">]\` Multiply two FloatLiterals. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to multiply. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The product of the two values. ### `__truediv__` `__truediv__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< truediv value, *"value`2x">]\` Divide two FloatLiterals. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to divide. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The quotient of the two values. ### `__floordiv__` `__floordiv__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< floordiv value, *"value`2x">]\` Returns self divided by rhs, rounded down to the nearest integer. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The divisor value. 
**Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): `floor(self / rhs)` value. ### `__mod__` `__mod__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< sub value, #pop.float_literal_bin< mul #pop.float_literal_bin< floordiv value, *"value`2x">, \*"value`2x">>]` Return the remainder of self divided by rhs. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to divide on. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The remainder of dividing self by rhs. ### `__radd__` `__radd__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< add *"value`2x", value>]\` Reversed addition operator. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to add. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The sum of this and the given value. ### `__rsub__` `__rsub__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< sub *"value`2x", value>]\` Reversed subtraction operator. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to subtract from. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The result of subtracting this from the given value. ### `__rmul__` `__rmul__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< mul *"value`2x", value>]\` Reversed multiplication operator. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to multiply. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The product of the given number and this. ### `__rtruediv__` `__rtruediv__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< truediv *"value`2x", value>]\` Reversed division. 
**Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to be divided by this. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The result of dividing the given value by this. ### `__rfloordiv__` `__rfloordiv__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< floordiv *"value`2x", value>]\` Returns rhs divided by self, rounded down to the nearest integer. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to be divided by self. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): `floor(rhs / self)` value. ### `__rmod__` `__rmod__(self, rhs: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< sub *"value`2x", #pop.float\_literal\_bin< mul #pop.float\_literal\_bin< floordiv \*"value`2x", value>, value>>]` Return the remainder of rhs divided by self. **Args:** * ​rhs ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The value to divide on. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The remainder of dividing rhs by self. ### `is_nan` `is_nan(self) -> Bool` Return whether the FloatLiteral is nan. Since `nan == nan` is False, this provides a way to check for nan-ness. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True, if the value is nan, False otherwise. ### `is_neg_zero` `is_neg_zero(self) -> Bool` Return whether the FloatLiteral is negative zero. Since `FloatLiteral.negative_zero == 0.0` is True, this provides a way to check if the FloatLiteral is negative zero. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True, if the value is negative zero, False otherwise. ### `__str__` `__str__(self) -> String` Get the float as a string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation. 
### `__int_literal__` `__int_literal__(self) -> IntLiteral[#pop.float_to_int_literal]` Casts the floating point value to an IntLiteral. If there is a fractional component, then the value is truncated towards zero. E.g. `(4.5).__int_literal__()` returns `4`, and `(-3.7).__int_literal__()` returns `-3`. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): The value as an integer. ### `__int__` `__int__(self) -> Int` Converts the FloatLiteral value to an Int. If there is a fractional component, then the value is truncated towards zero. E.g. `(4.5).__int__()` returns `4`, and `(-3.7).__int__()` returns `-3`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value as an integer. ### `__float__` `__float__(self) -> Float64` Converts the FloatLiteral to a concrete Float64. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The Float value. ### `__ceildiv__` `__ceildiv__(self, denominator: FloatLiteral[value]) -> FloatLiteral[#pop.float_literal_bin< mul #pop.float_literal_bin< floordiv value, #pop.float_literal_bin< mul *"value`2x", #pop.float\_literal<-1|1>>>, #pop.float\_literal<-1|1>>]\` Return the rounded-up result of dividing self by denominator. **Args:** * ​denominator ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The denominator. **Returns:** [`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): The ceiling of dividing numerator by denominator.
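Because the underlying value is an infinite precision parameter, `FloatLiteral` arithmetic at compile time avoids the rounding a runtime `Float64` would introduce; a hedged sketch:

```mojo
fn main():
    # All of this folds at compile time with infinite precision.
    comptime third = 1.0 / 3.0
    comptime one = third * 3.0  # exactly 1.0, no rounding error
    print(Float64(one))
    # Conversion truncates toward zero, as documented for __int__.
    print((4.5).__int__())      # 4
    print((-3.7).__int__())     # -3
```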
--- ## float_literal
Implements the FloatLiteral class. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral): Mojo floating point literal type.
--- ## Floatable
The `Floatable` trait describes a type that can be converted to a Float64. This trait requires the type to implement the `__float__()` method. For example: ```mojo struct Foo(Floatable): var i: Float64 fn __float__(self) -> Float64: return self.i ``` A `Foo` can now be converted to a `Float64`: ```mojo var f = Float64(Foo(5.5)) ``` **Note:** If the `__float__()` method can raise an error, use the [`FloatableRaising`](/mojo/stdlib/builtin/floatable/floatableraising) trait instead. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ## Required methods ### `__float__` `__float__(self: _Self) -> Float64` Get the floating point representation of the value. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The floating point representation of the value.
--- ## FloatableRaising
The `FloatableRaising` trait describes a type that can be converted to a Float64, but the conversion might raise an error (e.g.: a string). This trait requires the type to implement the `__float__()` method, which can raise an error. For example: ```mojo from utils import Variant struct MaybeFloat(FloatableRaising): var value: Variant[Float64, NoneType] fn __float__(self) raises -> Float64: if self.value.isa[NoneType](): raise "Float expected" return self.value[Float64] ``` A `MaybeFloat` can now be converted to `Float64`: ```mojo try: print(Float64(MaybeFloat(4.6))) except: print("error occurred") ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered a no-op. ## Required methods ### `__float__` `__float__(self: _Self) -> Float64` Get the floating point representation of the value. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The floating point representation of the value. **Raises:** If the type does not have a floating point representation.
--- ## floatable (Floatable)
Implements the `Floatable` and `FloatableRaising` traits. These are Mojo built-ins, so you don't need to import them. ## Traits * [​`Floatable`](/mojo/stdlib/builtin/floatable/Floatable): The `Floatable` trait describes a type that can be converted to a Float64. * [​`FloatableRaising`](/mojo/stdlib/builtin/floatable/FloatableRaising): The `FloatableRaising` trait describes a type that can be converted to a Float64, but the conversion might raise an error (e.g.: a string).
--- ## bin
`bin(num: Scalar[dtype], /, *, prefix: StringSlice[StaticConstantOrigin] = "0b") -> String` Return the binary string representation of an integral value. ```mojo print(bin(123)) print(bin(-123)) ``` ```plaintext '0b1111011' '-0b1111011' ``` **Args:** * ​num ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): An integral scalar value. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The binary string representation of num. `bin(b: Scalar[DType.bool], /, *, prefix: StringSlice[StaticConstantOrigin] = "0b") -> String` Returns the binary representation of a scalar bool. **Args:** * ​b ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): A scalar bool value. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The binary string representation of b. `bin[T: Intable, //](num: T, /, *, prefix: StringSlice[StaticConstantOrigin] = "0b") -> String` Returns the binary representation of an indexer type. **Parameters:** * ​T ([`Intable`](/mojo/stdlib/builtin/int/Intable)): The Intable type. **Args:** * ​num (`T`): An indexer value. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The binary string representation of num.
--- ## hex
`hex(value: Scalar[dtype], /, *, prefix: StringSlice[StaticConstantOrigin] = "0x") -> String` Returns the hex string representation of the given integer. The hexadecimal representation is a base-16 encoding of the integer value. The returned string will be prefixed with "0x" to indicate that the subsequent digits are hex. **Args:** * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The integer value to format. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the hex representation of the given integer. `hex[T: Intable, //](value: T, /, *, prefix: StringSlice[StaticConstantOrigin] = "0x") -> String` Returns the hex string representation of the given integer. The hexadecimal representation is a base-16 encoding of the integer value. The returned string will be prefixed with "0x" to indicate that the subsequent digits are hex. **Parameters:** * ​T ([`Intable`](/mojo/stdlib/builtin/int/Intable)): The indexer type to represent in hexadecimal. **Args:** * ​value (`T`): The integer value to format. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the hex representation of the given integer. `hex(value: Scalar[DType.bool], /, *, prefix: StringSlice[StaticConstantOrigin] = "0x") -> String` Returns the hex string representation of the given scalar bool. The hexadecimal representation is a base-16 encoding of the bool. The returned string will be prefixed with "0x" to indicate that the subsequent digits are hex. **Args:** * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The bool value to format. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the hex representation of the given bool.
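A short usage sketch of `hex`, mirroring the style of the `bin` example earlier:

```mojo
fn main():
    print(hex(255))             # '0xff'
    print(hex(255, prefix=""))  # 'ff'
```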
--- ## format_int
Provides the `hex` and `bin` functions. These are Mojo built-ins, so you don't need to import them. ## Functions * [​`bin`](/mojo/stdlib/builtin/format_int/bin): Return the binary string representation of an integral value. * [​`hex`](/mojo/stdlib/builtin/format_int/hex): Returns the hex string representation of the given integer. * [​`oct`](/mojo/stdlib/builtin/format_int/oct): Returns the octal string representation of the given integer.
--- ## oct
`oct(value: Scalar[dtype], /, *, prefix: StringSlice[StaticConstantOrigin] = "0o") -> String` Returns the octal string representation of the given integer. The octal representation is a base-8 encoding of the integer value. The returned string will be prefixed with "0o" to indicate that the subsequent digits are octal. **Args:** * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The integer value to format. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the octal representation of the given integer. `oct[T: Intable, //](value: T, /, *, prefix: StringSlice[StaticConstantOrigin] = "0o") -> String` Returns the octal string representation of the given integer. The octal representation is a base-8 encoding of the integer value. The returned string will be prefixed with "0o" to indicate that the subsequent digits are octal. **Parameters:** * ​T ([`Intable`](/mojo/stdlib/builtin/int/Intable)): The intable type to represent in octal. **Args:** * ​value (`T`): The integer value to format. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the octal representation of the given integer. `oct(value: Scalar[DType.bool], /, *, prefix: StringSlice[StaticConstantOrigin] = "0o") -> String` Returns the octal string representation of the given scalar bool. The octal representation is a base-8 encoding of the bool. The returned string will be prefixed with "0o" to indicate that the subsequent digits are octal. **Args:** * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The bool value to format. * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix of the formatted int. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the octal representation of the given bool.
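A short usage sketch (assuming the default `"0o"` prefix documented above):

```mojo
fn main():
    # 8 in base 8 is 10, so the default prefix yields "0o10".
    print(oct(8))
    # A custom prefix replaces the default "0o".
    print(oct(8, prefix=""))
```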
--- ## global_constant
`global_constant[T: AnyType, //, value: T]() -> ref [StaticConstantOrigin] T` Creates a reference to a compile-time constant value. This function uses the MLIR `pop.global_constant` operation to create a reference to a compile-time value without materializing the entire value at runtime. This is particularly useful for large lookup tables where you want to avoid materializing the entire table when accessing individual elements. Examples:

```mojo
from builtin.globals import global_constant

# Create a reference to a constant array and access elements
comptime lookup_table = InlineArray[Int, 4](1, 2, 3, 4)
var element = global_constant[lookup_table]()[2]  # Access without materializing entire array
print(element)  # Prints: 3

# Use with more complex compile-time values
fn compute(x: Int) -> Int:
    return x * 2 + 1

comptime data = InlineArray[Int, 3](1, compute(5), 100)
ref data_ref = global_constant[data]()
print(data_ref[0], data_ref[1], data_ref[2])  # Prints: 1 11 100
```

**Parameters:** * ​T ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The type of the constant value. * ​value (`T`): The compile-time constant value. **Returns:** `ref`: A reference to the global constant.
--- ## globals
Utilities for working with global constants. This module provides helper functions for efficiently creating references to compile-time constants without materializing entire data structures in memory. ## Functions * [​`global_constant`](/mojo/stdlib/builtin/globals/global_constant): Creates a reference to a compile-time constant value.
--- ## Identifiable
The Identifiable trait denotes a type with an identity which can be compared with other instances of itself. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__is__` `__is__(self: _Self, rhs: _Self) -> Bool` Define whether `self` has the same identity as `rhs`. **Args:** * ​rhs (`_Self`): The right hand side of the comparison. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is `rhs`. ## Provided methods ### `__isnot__` `__isnot__(self: _Self, rhs: _Self) -> Bool` Define whether `self` has a different identity than `rhs`. **Args:** * ​rhs (`_Self`): The right hand side of the comparison. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is not `rhs`.
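A minimal conforming type might look like the following sketch (the `Node` struct is hypothetical; identity is modeled as pointer equality, so the provided `__isnot__` comes for free):

```mojo
struct Node(Identifiable):
    var data: UnsafePointer[Int]

    fn __is__(self, rhs: Self) -> Bool:
        # Two Nodes share an identity when they wrap the same storage,
        # regardless of the values stored there.
        return self.data == rhs.data
```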
--- ## identifiable (Identifiable)
## Traits * [​`Identifiable`](/mojo/stdlib/builtin/identifiable/Identifiable): The Identifiable trait denotes a type with an identity which can be compared with other instances of itself.
--- ## builtin
Implements the builtin package. ## Modules * [​`anytype`](/mojo/stdlib/builtin/anytype/): Defines the core traits for object lifetime management in Mojo. * [​`bool`](/mojo/stdlib/builtin/bool/): Implements the Bool class. * [​`breakpoint`](/mojo/stdlib/builtin/breakpoint/): This module includes the builtin breakpoint function. * [​`builtin_slice`](/mojo/stdlib/builtin/builtin_slice/): Implements slice. * [​`comparable`](/mojo/stdlib/builtin/comparable/): * [​`constrained`](/mojo/stdlib/builtin/constrained/): Implements compile-time constraints. * [​`coroutine`](/mojo/stdlib/builtin/coroutine/): Implements classes and methods for coroutines. * [​`debug_assert`](/mojo/stdlib/builtin/debug_assert/): Implements run-time assertions. * [​`device_passable`](/mojo/stdlib/builtin/device_passable/): * [​`dtype`](/mojo/stdlib/builtin/dtype/): Implements the DType class. * [​`error`](/mojo/stdlib/builtin/error/): Implements the Error class. * [​`float_literal`](/mojo/stdlib/builtin/float_literal/): Implements the FloatLiteral class. * [​`floatable`](/mojo/stdlib/builtin/floatable/): Implements the `Floatable` and `FloatableRaising` traits. * [​`format_int`](/mojo/stdlib/builtin/format_int/): Provides the `hex` and `bin` functions. * [​`globals`](/mojo/stdlib/builtin/globals/): Utilities for working with global constants. * [​`identifiable`](/mojo/stdlib/builtin/identifiable/): * [​`int`](/mojo/stdlib/builtin/int/): Implements the Int class. * [​`int_literal`](/mojo/stdlib/builtin/int_literal/): Implements the IntLiteral class. * [​`len`](/mojo/stdlib/builtin/len/): Provides the `len()` function and its associated traits. * [​`math`](/mojo/stdlib/builtin/math/): Defines basic math functions for use in the open source parts of the standard library since the `math` package is currently closed source and cannot be depended on in the open source parts of the standard library. * [​`none`](/mojo/stdlib/builtin/none/): Defines the builtin `NoneType`. 
* [​`range`](/mojo/stdlib/builtin/range/): Implements a 'range' call. * [​`rebind`](/mojo/stdlib/builtin/rebind/): Implements type rebind/trait downcast. * [​`repr`](/mojo/stdlib/builtin/repr/): Provides the `repr` function. * [​`reversed`](/mojo/stdlib/builtin/reversed/): Provides the `reversed` function for reverse iteration over collections. * [​`simd`](/mojo/stdlib/builtin/simd/): Implements SIMD primitives and abstractions. * [​`sort`](/mojo/stdlib/builtin/sort/): Implements the built-in `sort` function. * [​`str`](/mojo/stdlib/builtin/str/): Provides the `Stringable` and `StringableRaising` traits. * [​`string_literal`](/mojo/stdlib/builtin/string_literal/): Implements the StringLiteral struct. * [​`swap`](/mojo/stdlib/builtin/swap/): Implements the built-in `swap` function. * [​`tuple`](/mojo/stdlib/builtin/tuple/): Implements the Tuple type. * [​`type_aliases`](/mojo/stdlib/builtin/type_aliases/): Defines some type aliases. * [​`uint`](/mojo/stdlib/builtin/uint/): Implements the UInt class. * [​`value`](/mojo/stdlib/builtin/value/): Defines core value traits. * [​`variadics`](/mojo/stdlib/builtin/variadics/): Implements the VariadicList and VariadicPack types.
--- ## Indexer
The `Indexer` trait is used for types that can index into a collection or pointer. The type returned is the underlying \_\_mlir\_type.index, so that types like `UInt` don't have to be converted to an `Int` first. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__mlir_index__` `__mlir_index__(self: _Self) -> __mlir_type.index` Convert to index. **Returns:** `__mlir_type.index`: The corresponding \_\_mlir\_type.index value.
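A conforming type can simply delegate to `Int`'s own conversion; the `MyIndex` wrapper below is a hypothetical sketch:

```mojo
struct MyIndex(Indexer):
    var value: Int

    fn __mlir_index__(self) -> __mlir_type.index:
        # Delegate to Int, which already conforms to Indexer.
        return self.value.__mlir_index__()
```

Values of `MyIndex` could then be used directly wherever an `Indexer` is accepted, without first constructing an `Int`.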
--- ## Int
`@register_passable(trivial)` `struct Int` This type represents an integer value. ## Implemented traits [`Absable`](/mojo/stdlib/builtin/math/Absable), [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`CeilDivable`](/mojo/stdlib/math/math/CeilDivable), [`Ceilable`](/mojo/stdlib/math/math/Ceilable), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`ConvertibleFromPython`](/mojo/stdlib/python/conversions/ConvertibleFromPython), [`ConvertibleToPython`](/mojo/stdlib/python/conversions/ConvertibleToPython), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`DivModable`](/mojo/stdlib/builtin/math/DivModable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Floorable`](/mojo/stdlib/math/math/Floorable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Indexer`](/mojo/stdlib/builtin/int/Indexer), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`IntervalElement`](/mojo/stdlib/collections/interval/IntervalElement), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Powable`](/mojo/stdlib/builtin/math/Powable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Roundable`](/mojo/stdlib/builtin/math/Roundable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`Truncable`](/mojo/stdlib/math/math/Truncable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BITWIDTH` `comptime BITWIDTH = Int.__init__[Int](bit_width_of[DType.index]())` The bit width of the integer type. 
### `device_type` `comptime device_type = Int` Int is remapped to the same type when passed to accelerator devices. ### `MAX` `comptime MAX = Int.__init__[Scalar[DType.index]](Scalar[DType.index].MAX)` Returns the maximum integer value. ### `MIN` `comptime MIN = Int.__init__[Scalar[DType.index]](Scalar[DType.index].MIN)` Returns the minimum integer value. ## Methods ### `__init__` `__init__() -> Self` Default constructor that produces zero. `@implicit` `__init__(value: IntLiteral[value]) -> Self` Construct Int from the given IntLiteral value. **Args:** * ​value ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The init value. `__init__(value: UInt) -> Self` Construct Int from the given UInt value. **Args:** * ​value ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The init value. `__init__[T: Intable](value: T) -> Self` Get the Int representation of the value. **Parameters:** * ​T ([`Intable`](/mojo/stdlib/builtin/int/Intable)): The Intable type. **Args:** * ​value (`T`): The object to get the integral representation of. `__init__[T: IntableRaising](out self, value: T)` Get the Int representation of the value. **Parameters:** * ​T ([`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising)): The Intable type. **Args:** * ​value (`T`): The object to get the integral representation of. **Raises:** If the type does not have an integral representation. ### `__bool__` `__bool__(self) -> Bool` Convert this Int to Bool. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): False Bool value if the value is equal to 0 and True otherwise. ### `__neg__` `__neg__(self) -> Self` Return -self. **Returns:** `Self`: The -self value. ### `__pos__` `__pos__(self) -> Self` Return +self. **Returns:** `Self`: The +self value. ### `__invert__` `__invert__(self) -> Self` Return \~self. **Returns:** `Self`: The \~self value. ### `__lt__` `__lt__(self, rhs: Self) -> Bool` Compare this Int to the RHS using LT comparison. **Args:** * ​rhs (`Self`): The other Int to compare against. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this Int is less-than the RHS Int and False otherwise. ### `__le__` `__le__(self, rhs: Self) -> Bool` Compare this Int to the RHS using LE comparison. **Args:** * ​rhs (`Self`): The other Int to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this Int is less-or-equal than the RHS Int and False otherwise. ### `__eq__` `__eq__(self, rhs: Self) -> Bool` Compare this Int to the RHS using EQ comparison. **Args:** * ​rhs (`Self`): The other Int to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this Int is equal to the RHS Int and False otherwise. ### `__ne__` `__ne__(self, rhs: Self) -> Bool` Compare this Int to the RHS using NE comparison. **Args:** * ​rhs (`Self`): The other Int to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this Int is non-equal to the RHS Int and False otherwise. ### `__gt__` `__gt__(self, rhs: Self) -> Bool` Compare this Int to the RHS using GT comparison. **Args:** * ​rhs (`Self`): The other Int to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this Int is greater than the RHS Int and False otherwise. ### `__ge__` `__ge__(self, rhs: Self) -> Bool` Compare this Int to the RHS using GE comparison. **Args:** * ​rhs (`Self`): The other Int to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this Int is greater-or-equal than the RHS Int and False otherwise. ### `__add__` `__add__(self, rhs: Self) -> Self` Return `self + rhs`. **Args:** * ​rhs (`Self`): The value to add. **Returns:** `Self`: `self + rhs` value. ### `__sub__` `__sub__(self, rhs: Self) -> Self` Return `self - rhs`. **Args:** * ​rhs (`Self`): The value to subtract. **Returns:** `Self`: `self - rhs` value. ### `__mul__` `__mul__(self, rhs: Self) -> Self` Return `self * rhs`. **Args:** * ​rhs (`Self`): The value to multiply with. 
**Returns:** `Self`: `self * rhs` value. ### `__truediv__` `__truediv__(self, rhs: Self) -> Float64` Return the floating point division of `self` and `rhs`. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): `Float64(self)/Float64(rhs)` value. ### `__floordiv__` `__floordiv__(self, rhs: Self) -> Self` Return the division of `self` and `rhs` rounded down to the nearest integer. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** `Self`: `floor(self/rhs)` value. ### `__mod__` `__mod__(self, rhs: Self) -> Self` Return the remainder of self divided by rhs. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** `Self`: The remainder of dividing self by rhs. ### `__pow__` `__pow__(self, exp: Self) -> Self` Return the value raised to the power of the given exponent. Computes the power of an integer using the Russian Peasant Method. **Args:** * ​exp (`Self`): The exponent value. **Returns:** `Self`: The value of `self` raised to the power of `exp`. ### `__lshift__` `__lshift__(self, rhs: Self) -> Self` Return `self << rhs`. **Args:** * ​rhs (`Self`): The value to shift with. **Returns:** `Self`: `self << rhs`. ### `__rshift__` `__rshift__(self, rhs: Self) -> Self` Return `self >> rhs`. **Args:** * ​rhs (`Self`): The value to shift with. **Returns:** `Self`: `self >> rhs`. ### `__and__` `__and__(self, rhs: Self) -> Self` Return `self & rhs`. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self & rhs`. ### `__or__` `__or__(self, rhs: Self) -> Self` Return `self | rhs`. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self | rhs`. ### `__xor__` `__xor__(self, rhs: Self) -> Self` Return `self ^ rhs`. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self ^ rhs`. ### `__radd__` `__radd__(self, value: Self) -> Self` Return `value + self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value + self`. 
### `__rsub__` `__rsub__(self, value: Self) -> Self` Return `value - self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value - self`. ### `__rmul__` `__rmul__(self, value: Self) -> Self` Return `value * self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value * self`. ### `__rfloordiv__` `__rfloordiv__(self, value: Self) -> Self` Return `value // self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value // self`. ### `__rmod__` `__rmod__(self, value: Self) -> Self` Return `value % self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value % self`. ### `__rpow__` `__rpow__(self, value: Self) -> Self` Return `pow(value,self)`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `pow(value,self)`. ### `__rlshift__` `__rlshift__(self, value: Self) -> Self` Return `value << self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value << self`. ### `__rrshift__` `__rrshift__(self, value: Self) -> Self` Return `value >> self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value >> self`. ### `__rand__` `__rand__(self, value: Self) -> Self` Return `value & self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value & self`. ### `__ror__` `__ror__(self, value: Self) -> Self` Return `value | self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value | self`. ### `__rxor__` `__rxor__(self, value: Self) -> Self` Return `value ^ self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value ^ self`. ### `__iadd__` `__iadd__(mut self, rhs: Self)` Compute `self + rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__isub__` `__isub__(mut self, rhs: Self)` Compute `self - rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__imul__` `__imul__(mut self, rhs: Self)` Compute self\*rhs and save the result in self. 
**Args:** * ​rhs (`Self`): The RHS value. ### `__itruediv__` `__itruediv__(mut self, rhs: Self)` Compute `self / rhs`, convert to int, and save the result in self. Since `floor(self / rhs)` is equivalent to `self // rhs`, this yields the same as `__ifloordiv__`. **Args:** * ​rhs (`Self`): The RHS value. ### `__ifloordiv__` `__ifloordiv__(mut self, rhs: Self)` Compute `self // rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__imod__` `__imod__(mut self, rhs: Self)` Compute `self % rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__ipow__` `__ipow__(mut self, rhs: Self)` Compute `pow(self, rhs)` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__ilshift__` `__ilshift__(mut self, rhs: Self)` Compute `self << rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__irshift__` `__irshift__(mut self, rhs: Self)` Compute `self >> rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__iand__` `__iand__(mut self, rhs: Self)` Compute `self & rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__ixor__` `__ixor__(mut self, rhs: Self)` Compute `self ^ rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__ior__` `__ior__(mut self, rhs: Self)` Compute self|rhs and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. TODO: This will go away soon, when we get better error messages for kernel calls. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. TODO: This will go away soon, when we get better error messages for kernel calls. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This type's name. ### `__divmod__` `__divmod__(self, rhs: Self) -> Tuple[Int, Int]` Computes both the quotient and remainder using integer division. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** `Tuple`: The quotient and remainder as a tuple `(self // rhs, self % rhs)`. ### `__mlir_index__` `__mlir_index__(self) -> __mlir_type.index` Convert to index. **Returns:** `__mlir_type.index`: The corresponding \_\_mlir\_type.index value. ### `__int__` `__int__(self) -> Self` Gets the integral value (this is an identity function for Int). **Returns:** `Self`: The value as an integer. ### `__abs__` `__abs__(self) -> Self` Return the absolute value of the Int value. **Returns:** `Self`: The absolute value. ### `__ceil__` `__ceil__(self) -> Self` Return the ceiling of the Int value, which is itself. **Returns:** `Self`: The Int value itself. ### `__floor__` `__floor__(self) -> Self` Return the floor of the Int value, which is itself. **Returns:** `Self`: The Int value itself. ### `__round__` `__round__(self) -> Self` Return the rounded value of the Int value, which is itself. **Returns:** `Self`: The Int value itself. `__round__(self, ndigits: Self) -> Self` Return the rounded value of the Int value, which is itself. **Args:** * ​ndigits (`Self`): The number of digits to round to. **Returns:** `Self`: The Int value itself if ndigits >= 0 else the rounded value. ### `__trunc__` `__trunc__(self) -> Self` Return the truncated Int value, which is itself. **Returns:** `Self`: The Int value itself. ### `__ceildiv__` `__ceildiv__(self, denominator: Self) -> Self` Return the rounded-up result of dividing self by denominator. **Args:** * ​denominator (`Self`): The denominator. **Returns:** `Self`: The ceiling of dividing numerator by denominator. ### `is_power_of_two` `is_power_of_two(self) -> Bool` Check if the integer is a (non-zero) power of two. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the integer is a power of two, False otherwise. ### `write_to` `write_to(self, mut writer: T)` Formats this integer to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `write_padded` `write_padded[W: Writer](self, mut writer: W, width: Self)` Write the int right-aligned to a set padding. **Parameters:** * ​W ([`Writer`](/mojo/stdlib/io/write/Writer)): A type conforming to the Writable trait. **Args:** * ​writer (`W`): The object to write to. * ​width (`Self`): The amount to pad to the left. ### `__str__` `__str__(self) -> String` Get the integer as a string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation. ### `__repr__` `__repr__(self) -> String` Get the integer as a string. Returns the same `String` as `__str__`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation. ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with this int value. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance. ### `to_python_object` `to_python_object(var self) -> PythonObject` Convert this value to a PythonObject. **Returns:** `PythonObject`: A PythonObject representing the value. **Raises:** If the Python runtime is not initialized or conversion fails.
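A few of the methods above in combination (a minimal sketch):

```mojo
fn main():
    var x = 17
    # __floordiv__ and __mod__: floor division and remainder.
    print(x // 5, x % 5)
    # __divmod__ computes both at once as a (quotient, remainder) tuple.
    var qr = x.__divmod__(5)
    print(qr[0], qr[1])
    # is_power_of_two: True only for non-zero powers of two.
    print(Int(16).is_power_of_two(), Int(12).is_power_of_two())
```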
--- ## Intable
The `Intable` trait describes a type that can be converted to an Int. Any type that conforms to `Intable` or [`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising) can construct an `Int`. This trait requires the type to implement the `__int__()` method. For example:

```mojo
struct Foo(Intable):
    var i: Int

    fn __int__(self) -> Int:
        return self.i
```

Now you can construct an `Int`:

```mojo
foo = Foo(42)
assert_equal(Int(foo), 42)
```

**Note:** If the `__int__()` method can raise an error, use the [`IntableRaising`](/mojo/stdlib/builtin/int/intableraising) trait instead. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__int__` `__int__(self: _Self) -> Int` Get the integral representation of the value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integral representation of the value.
--- ## IntableRaising
The `IntableRaising` trait describes a type that can be converted to an Int, but the conversion might raise an error. Any type that conforms to [`Intable`](/mojo/stdlib/builtin/int/Intable) or `IntableRaising` can construct an `Int`. This trait requires the type to implement the `__int__()` method, which can raise an error. For example:

```mojo
struct Foo(IntableRaising):
    var i: Int

    fn __int__(self) raises -> Int:
        return self.i
```

Now you can construct an `Int`:

```mojo
foo = Foo(42)
assert_equal(Int(foo), 42)
```

## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__int__` `__int__(self: _Self) -> Int` Get the integral representation of the value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integral representation of the type. **Raises:** If the type does not have an integral representation.
--- ## index
`index[T: Indexer](idx: T, /) -> Int` Returns the value of `__mlir_index__` for the given value. **Parameters:** * ​T ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): A type conforming to the `Indexer` trait. **Args:** * ​idx (`T`): The value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An `Int` representing the index value.
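For instance (a minimal sketch):

```mojo
fn main():
    var u = UInt(7)
    # UInt conforms to Indexer, so no explicit Int conversion is needed.
    var i = index(u)
    print(i)
```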
--- ## int (Int)
Implements the Int class. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`Int`](/mojo/stdlib/builtin/int/Int): This type represents an integer value. ## Traits * [​`Indexer`](/mojo/stdlib/builtin/int/Indexer): The `Indexer` trait is used for types that can index into a collection or pointer. The type returned is the underlying \_\_mlir\_type.index, enabling types like `UInt` to not have to be converted to an `Int` first. * [​`Intable`](/mojo/stdlib/builtin/int/Intable): The `Intable` trait describes a type that can be converted to an Int. * [​`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising): The `IntableRaising` trait describes a type that can be converted to an Int, but the conversion might raise an error. ## Functions * [​`index`](/mojo/stdlib/builtin/int/index-function): Returns the value of `__mlir_index__` for the given value.
--- ## IntLiteral
`@register_passable(trivial)` `struct IntLiteral[value: __mlir_type.`!pop.int\_literal`]` This type represents a static integer literal value with infinite precision. This type is a compile-time construct which stores its value as a parameter. It is typically materialized into other types (like `Int`) for use at runtime. This compile-time representation allows for arbitrary precision constants that would overflow on Int and other fixed precision integer types. ## Parameters * ​value (`__mlir_type.`!pop.int\_literal\`\`): The underlying integer value. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Ceilable`](/mojo/stdlib/math/math/Ceilable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Floorable`](/mojo/stdlib/math/math/Floorable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Indexer`](/mojo/stdlib/builtin/int/Indexer), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`Truncable`](/mojo/stdlib/math/math/Truncable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Constructor for any value. ### `__bool__` `__bool__(self) -> Bool` Convert this IntLiteral to Bool. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): False Bool value if the value is equal to 0 and True otherwise. ### `__neg__` `__neg__(self) -> IntLiteral[(0 - value)]` Return -self. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): The -self value. ### `__pos__` `__pos__(self) -> Self` Return +self. 
**Returns:** `Self`: The +self value. ### `__invert__` `__invert__(self) -> IntLiteral[(value ^ -1)]` Return \~self. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): The \~self value. ### `__lt__` `__lt__(self, rhs: IntLiteral[value]) -> Bool` Compare this IntLiteral to the RHS using LT comparison. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The other IntLiteral to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this IntLiteral is less-than the RHS IntLiteral and False otherwise. ### `__le__` `__le__(self, rhs: IntLiteral[value]) -> Bool` Compare this IntLiteral to the RHS using LE comparison. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The other IntLiteral to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this IntLiteral is less-or-equal than the RHS IntLiteral and False otherwise. ### `__eq__` `__eq__(self, rhs: IntLiteral[value]) -> Bool` Compare this IntLiteral to the RHS using EQ comparison. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The other IntLiteral to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this IntLiteral is equal to the RHS IntLiteral and False otherwise. ### `__ne__` `__ne__(self, rhs: IntLiteral[value]) -> Bool` Compare this IntLiteral to the RHS using NE comparison. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The other IntLiteral to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this IntLiteral is non-equal to the RHS IntLiteral and False otherwise. ### `__gt__` `__gt__(self, rhs: IntLiteral[value]) -> Bool` Compare this IntLiteral to the RHS using GT comparison. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The other IntLiteral to compare against. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this IntLiteral is greater-than the RHS IntLiteral and False otherwise. ### `__ge__` `__ge__(self, rhs: IntLiteral[value]) -> Bool` Compare this IntLiteral to the RHS using GE comparison. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The other IntLiteral to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this IntLiteral is greater-or-equal than the RHS IntLiteral and False otherwise. ### `__add__` `__add__(self, rhs: IntLiteral[value]) -> IntLiteral[(value + value)]` Return `self + rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The value to add. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self + rhs` value. ### `__sub__` `__sub__(self, rhs: IntLiteral[value]) -> IntLiteral[(value - value)]` Return `self - rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The value to subtract. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self - rhs` value. ### `__mul__` `__mul__(self, rhs: IntLiteral[value]) -> IntLiteral[(value * value)]` Return `self * rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The value to multiply by. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self * rhs` value. ### `__floordiv__` `__floordiv__(self, rhs: IntLiteral[value]) -> IntLiteral[(value // value)]` Return `self // rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The value to divide by. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self // rhs` value. ### `__mod__` `__mod__(self, rhs: IntLiteral[value]) -> IntLiteral[(value % value)]` Return the remainder of self divided by rhs. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The value to divide by. 
**Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): The remainder of dividing self by rhs. ### `__lshift__` `__lshift__(self, rhs: IntLiteral[value]) -> IntLiteral[(value << value)]` Return `self << rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The value to shift with. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self << rhs`. ### `__rshift__` `__rshift__(self, rhs: IntLiteral[value]) -> IntLiteral[(value >> value)]` Return `self >> rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The value to shift with. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self >> rhs`. ### `__and__` `__and__(self, rhs: IntLiteral[value]) -> IntLiteral[(value & value)]` Return `self & rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The RHS value. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self & rhs`. ### `__or__` `__or__(self, rhs: IntLiteral[value]) -> IntLiteral[(value | value)]` Return `self | rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The RHS value. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self | rhs`. ### `__xor__` `__xor__(self, rhs: IntLiteral[value]) -> IntLiteral[(value ^ value)]` Return `self ^ rhs`. **Args:** * ​rhs ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The RHS value. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): `self ^ rhs`. ### `__int__` `__int__(self) -> Int` Convert from IntLiteral to Int. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value as an integer of platform-specific width. ### `__uint__` `__uint__(self) -> UInt` Convert from IntLiteral to UInt. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The value as an unsigned integer of platform-specific width. 
### `__ceil__` `__ceil__(self) -> Self` Return the ceiling of the IntLiteral value, which is itself. **Returns:** `Self`: The IntLiteral value itself. ### `__floor__` `__floor__(self) -> Self` Return the floor of the IntLiteral value, which is itself. **Returns:** `Self`: The IntLiteral value itself. ### `__trunc__` `__trunc__(self) -> Self` Return the truncated IntLiteral value, which is itself. **Returns:** `Self`: The IntLiteral value itself. ### `__str__` `__str__(self) -> String` Convert from IntLiteral to String. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The value as a string. ### `__ceildiv__` `__ceildiv__(self, denominator: IntLiteral[value]) -> IntLiteral[(0 - (value // (0 - value)))]` Return the rounded-up result of dividing self by denominator. **Args:** * ​denominator ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The denominator. **Returns:** [`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): The ceiling of dividing self by denominator. ### `__mlir_index__` `__mlir_index__(self) -> __mlir_type.index` Convert from IntLiteral to index. **Returns:** `__mlir_type.index`: The corresponding \_\_mlir\_type.index value, interpreting as signed.
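As a brief sketch (not part of the generated reference): literal arithmetic is folded at compile time with arbitrary precision, and only the final value is materialized into a fixed-width type such as `Int`.

```mojo
fn main():
    # The literal expression below is evaluated at compile time with
    # arbitrary precision, then materialized into a runtime `Int`.
    var x: Int = (1 << 40) // 256
    print(x)  # 4294967296
```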
--- ## int_literal
Implements the IntLiteral class. ## Structs * [​`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral): This type represents a static integer literal value with infinite precision. This type is a compile-time construct which stores its value as a parameter. It is typically materialized into other types (like `Int`) for use at runtime. This compile-time representation allows for arbitrary precision constants that would overflow on Int and other fixed precision integer types.
--- ## Sized
The `Sized` trait describes a type that has an integer length (such as a string or array). Any type that conforms to `Sized` or [`SizedRaising`](/mojo/stdlib/builtin/len/SizedRaising) works with the built-in [`len()`](/mojo/stdlib/builtin/len/len) function. The `Sized` trait requires a type to implement the `__len__()` method. For example: ```mojo @fieldwise_init struct Foo(Sized): var length: Int fn __len__(self) -> Int: return self.length ``` You can pass an instance of `Foo` to the `len()` function to get its length: ```mojo var foo = Foo(42) print(len(foo) == 42) ``` ```plaintext True ``` **Note:** If the `__len__()` method can raise an error, use the [`SizedRaising`](/mojo/stdlib/builtin/len/SizedRaising) trait instead. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__len__` `__len__(self: _Self) -> Int` Get the length of the type. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the type.
--- ## SizedRaising
The `SizedRaising` trait describes a type that has an integer length, which might raise an error if the length can't be determined. Any type that conforms to [`Sized`](/mojo/stdlib/builtin/len/Sized) or `SizedRaising` works with the built-in [`len()`](/mojo/stdlib/builtin/len/len) function. The `SizedRaising` trait requires a type to implement the `__len__()` method, which can raise an error. For example: ```mojo @fieldwise_init struct Foo(SizedRaising): var length: Int fn __len__(self) raises -> Int: if self.length < 0: raise Error("Length is negative") return self.length ``` You can pass an instance of `Foo` to the `len()` function to get its length: ```mojo def main(): var foo = Foo(42) print(len(foo) == 42) ``` ```plaintext True ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__len__` `__len__(self: _Self) -> Int` Get the length of the type. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the type. **Raises:** If the length cannot be computed.
--- ## UIntSized
The `UIntSized` trait describes a type that has an unsigned integer length (such as a string or array). The `UIntSized` trait requires a type to implement the `__len__()` method, returning a [`UInt`](/mojo/stdlib/builtin/uint/UInt). For example: ```mojo @fieldwise_init struct Foo(UIntSized): var length: UInt fn __len__(self) -> UInt: return self.length ``` **Note:** The [`len()`](/mojo/stdlib/builtin/len/len) overloads in this module accept [`Sized`](/mojo/stdlib/builtin/len/Sized) and [`SizedRaising`](/mojo/stdlib/builtin/len/SizedRaising); `UIntSized` exposes the length as a `UInt` rather than an `Int`. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__len__` `__len__(self: _Self) -> UInt` Get the length of the type. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The length of the type.
--- ## len
Provides the `len()` function and its associated traits. These are Mojo built-ins, so you don't need to import them. ## Traits * [​`Sized`](/mojo/stdlib/builtin/len/Sized): The `Sized` trait describes a type that has an integer length (such as a string or array). * [​`SizedRaising`](/mojo/stdlib/builtin/len/SizedRaising): The `SizedRaising` trait describes a type that has an integer length, which might raise an error if the length can't be determined. * [​`UIntSized`](/mojo/stdlib/builtin/len/UIntSized): The `UIntSized` trait describes a type that has an unsigned integer length (such as a string or array). ## Functions * [​`len`](/mojo/stdlib/builtin/len/len): Get the length of a value.
--- ## len (Len)
`len[T: Sized](value: T) -> Int` Get the length of a value. **Parameters:** * ​T ([`Sized`](/mojo/stdlib/builtin/len/Sized)): The Sized type. **Args:** * ​value (`T`): The object to get the length of. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the object. `len[T: SizedRaising](value: T) -> Int` Get the length of a value. **Parameters:** * ​T ([`SizedRaising`](/mojo/stdlib/builtin/len/SizedRaising)): The Sized type. **Args:** * ​value (`T`): The object to get the length of. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the object. **Raises:** If the length cannot be computed.
--- ## Absable
The `Absable` trait describes a type that defines an absolute value operation. Types that conform to `Absable` will work with the builtin `abs` function. The absolute value operation always returns the same type as the input, so the example below wraps the magnitude back into a `Point`: ```mojo from math import sqrt @fieldwise_init struct Point(Absable): var x: Float64 var y: Float64 fn __abs__(self) -> Self: return Self(sqrt(self.x * self.x + self.y * self.y), 0) ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__abs__` `__abs__(self: _Self) -> _Self` Get the absolute value of this instance. **Returns:** `_Self`: The absolute value of the instance.
--- ## DivModable
The `DivModable` trait describes a type that defines division and modulo operations returning both quotient and remainder. Types that conform to `DivModable` will work with the builtin `divmod` function, which will return the same type as the inputs. For example: ```mojo @fieldwise_init struct Bytes(DivModable): var size: Int fn __divmod__(self, other: Self) -> Tuple[Self, Self]: var quotient_int = self.size // other.size var remainder_int = self.size % other.size return (Bytes(quotient_int), Bytes(remainder_int)) ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. 
The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `__divmod__` `__divmod__(self: _Self, denominator: _Self) -> Tuple[_Self, _Self]` Performs division and returns the quotient and the remainder. **Args:** * ​denominator (`_Self`): The value to divide by. **Returns:** `Tuple`: A `Tuple` containing the quotient and the remainder. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
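Assuming the `Bytes` example above compiles as written, a conforming type can be used with the builtin `divmod` function, which returns the quotient and remainder as a tuple:

```mojo
fn main():
    # divmod(10, 3) -> quotient 3, remainder 1
    var result = divmod(Bytes(10), Bytes(3))
    print(result[0].size, result[1].size)  # 3 1
```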
--- ## Powable
The `Powable` trait describes a type that defines a power operation (i.e. exponentiation) with the same base and exponent types. Types that conform to `Powable` will work with the builtin `pow` function, which will return the same type as the inputs. For example: ```mojo struct Rational(Powable): var numerator: Float64 var denominator: Float64 fn __init__(out self, numerator: Float64, denominator: Float64): self.numerator = numerator self.denominator = denominator fn __pow__(self, exp: Self) -> Self: var exp_value = exp.numerator / exp.denominator return Self(pow(self.numerator, exp_value), pow(self.denominator, exp_value)) ``` You can now use the \*\* operator to exponentiate objects inside generic functions: ```mojo fn exponentiate[T: Powable](base: T, exp: T) -> T: return base ** exp var base = Rational(Float64(3.0), 5.0) var exp = Rational(Float64(1.0), 2.0) var res = exponentiate(base, exp) ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__pow__` `__pow__(self: _Self, exp: _Self) -> _Self` Return the value raised to the power of the given exponent. **Args:** * ​exp (`_Self`): The exponent value. **Returns:** `_Self`: The value of `self` raised to the power of `exp`.
--- ## Roundable
The `Roundable` trait describes a type that defines a rounding operation. Types that conform to `Roundable` will work with the builtin `round` function. The round operation always returns the same type as the input. For example: ```mojo @fieldwise_init struct Complex(Roundable): var re: Float64 var im: Float64 fn __round__(self) -> Self: return Self(round(self.re), round(self.im)) fn __round__(self, ndigits: Int) -> Self: return Self(round(self.re, ndigits), round(self.im, ndigits)) ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__round__` `__round__(self: _Self) -> _Self` Get a rounded value for the type. **Returns:** `_Self`: The rounded value. `__round__(self: _Self, ndigits: Int) -> _Self` Get a rounded value for the type. **Args:** * ​ndigits ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of digits after the decimal point. **Returns:** `_Self`: The rounded value.
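Assuming the `Complex` example above, the builtin `round` function dispatches to the matching overload:

```mojo
fn main():
    # Rounds each component to the nearest whole value.
    var r = round(Complex(1.6, 2.4))
    print(r.re, r.im)  # 2.0 2.0
```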
--- ## abs (Math)
`abs[T: Absable](value: T) -> T` Get the absolute value of the given object. **Parameters:** * ​T ([`Absable`](/mojo/stdlib/builtin/math/Absable)): The type conforming to Absable. **Args:** * ​value (`T`): The object to get the absolute value of. **Returns:** `T`: The absolute value of the object.
--- ## divmod
`divmod[T: DivModable](numerator: T, denominator: T) -> Tuple[T, T]` Performs division and returns the quotient and the remainder. **Parameters:** * ​T ([`DivModable`](/mojo/stdlib/builtin/math/DivModable)): A type conforming to the `DivModable` trait. **Args:** * ​numerator (`T`): The dividend. * ​denominator (`T`): The divisor. **Returns:** `Tuple`: A `Tuple` containing the quotient and the remainder.
--- ## math (Math)
Defines basic math functions for use in the open source parts of the standard library. The `math` package is currently closed source, so the open source parts of the standard library cannot depend on it. These are Mojo built-ins, so you don't need to import them. ## Traits * [​`Absable`](/mojo/stdlib/builtin/math/Absable): The `Absable` trait describes a type that defines an absolute value operation. * [​`DivModable`](/mojo/stdlib/builtin/math/DivModable): The `DivModable` trait describes a type that defines division and modulo operations returning both quotient and remainder. * [​`Powable`](/mojo/stdlib/builtin/math/Powable): The `Powable` trait describes a type that defines a power operation (i.e. exponentiation) with the same base and exponent types. * [​`Roundable`](/mojo/stdlib/builtin/math/Roundable): The `Roundable` trait describes a type that defines a rounding operation. ## Functions * [​`abs`](/mojo/stdlib/builtin/math/abs): Get the absolute value of the given object. * [​`divmod`](/mojo/stdlib/builtin/math/divmod): Performs division and returns the quotient and the remainder. * [​`max`](/mojo/stdlib/builtin/math/max): Gets the maximum of two integers. * [​`min`](/mojo/stdlib/builtin/math/min): Gets the minimum of two integers. * [​`pow`](/mojo/stdlib/builtin/math/pow): Computes the `base` raised to the power of the `exp`. * [​`round`](/mojo/stdlib/builtin/math/round): Get the rounded value of the given object.
--- ## max (Math)
`max(x: Int, y: Int, /) -> Int` Gets the maximum of two integers. **Args:** * ​x ([`Int`](/mojo/stdlib/builtin/int/Int)): Integer input to max. * ​y ([`Int`](/mojo/stdlib/builtin/int/Int)): Integer input to max. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Maximum of x and y. `max(x: UInt, y: UInt, /) -> UInt` Gets the maximum of two integers. **Args:** * ​x ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Integer input to max. * ​y ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Integer input to max. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): Maximum of x and y. `max[dtype: DType, //](x: SIMD[dtype, size], y: SIMD[dtype, size], /) -> SIMD[dtype, size]` Performs elementwise maximum of x and y. An element of the result SIMD vector will be the maximum of the corresponding elements in x and y. **Constraints:** The type of the inputs must be numeric or boolean. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the SIMD vector. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): First SIMD vector. * ​y ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Second SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector containing the elementwise maximum of x and y. `max[T: Copyable & Comparable](x: T, *ys: T) -> T` Gets the maximum value from a sequence of values. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable)): A type that is both copyable and comparable with greater than. **Args:** * ​x (`T`): The first value to compare. * ​\*ys (`T`): Zero or more additional values to compare. **Returns:** `T`: The maximum value from the input sequence.
--- ## min (Math)
`min(x: Int, y: Int, /) -> Int` Gets the minimum of two integers. **Args:** * ​x ([`Int`](/mojo/stdlib/builtin/int/Int)): Integer input to min. * ​y ([`Int`](/mojo/stdlib/builtin/int/Int)): Integer input to min. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Minimum of x and y. `min(x: UInt, y: UInt, /) -> UInt` Gets the minimum of two integers. **Args:** * ​x ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Integer input to min. * ​y ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): Integer input to min. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): Minimum of x and y. `min[dtype: DType, //](x: SIMD[dtype, size], y: SIMD[dtype, size], /) -> SIMD[dtype, size]` Gets the elementwise minimum of x and y. An element of the result SIMD vector will be the minimum of the corresponding elements in x and y. **Constraints:** The type of the inputs must be numeric or boolean. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the SIMD vector. **Args:** * ​x ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): First SIMD vector. * ​y ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Second SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD vector containing the elementwise minimum of x and y. `min[T: Copyable & Comparable](x: T, *ys: T) -> T` Gets the minimum value from a sequence of values. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable)): A type that is both copyable and comparable with less than. **Args:** * ​x (`T`): The first value to compare. * ​\*ys (`T`): Zero or more additional values to compare. **Returns:** `T`: The minimum value from the input sequence.
--- ## pow
`pow[T: Powable](base: T, exp: T) -> T` Computes the `base` raised to the power of the `exp`. **Parameters:** * ​T ([`Powable`](/mojo/stdlib/builtin/math/Powable)): A type conforming to the `Powable` trait. **Args:** * ​base (`T`): The base of the power operation. * ​exp (`T`): The exponent of the power operation. **Returns:** `T`: The `base` raised to the power of the `exp`. `pow(base: SIMD[dtype, size], exp: Int) -> SIMD[dtype, size]` Computes elementwise value of a SIMD vector raised to the power of the given integer. **Args:** * ​base ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The first input argument. * ​exp ([`Int`](/mojo/stdlib/builtin/int/Int)): The second input argument. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The `base` elementwise raised raised to the power of `exp`.
--- ## round
`round[T: Roundable, //](number: T) -> T` Get the rounded value of the given object. **Parameters:** * ​T ([`Roundable`](/mojo/stdlib/builtin/math/Roundable)): The type conforming to Roundable. **Args:** * ​number (`T`): The object to get the rounded value of. **Returns:** `T`: The rounded value of the object. `round[T: Roundable, //](number: T, ndigits: Int) -> T` Get the value of this object, rounded to a specified number of digits after the decimal point. **Parameters:** * ​T ([`Roundable`](/mojo/stdlib/builtin/math/Roundable)): The type conforming to Roundable. **Args:** * ​number (`T`): The object to get the rounded value of. * ​ndigits ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of digits to round to. **Returns:** `T`: The rounded value of the object.
--- ## NoneType
`@register_passable(trivial)` `struct NoneType` Represents the absence of a value. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Construct an instance of the `None` type. `@implicit` `__init__(value: None) -> Self` Construct an instance of the `None` type. **Args:** * ​value (`None`): The MLIR none type to construct from. ### `__str__` `__str__(self) -> String` Returns the string representation of `None`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): `"None"`. ### `__repr__` `__repr__(self) -> String` Returns the string representation of `None`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): `"None"`. ### `write_to` `write_to(self, mut writer: T)` Write `None` to a writer stream. **Args:** * ​writer (`T`): The object to write to.
--- ## none
Defines the builtin `NoneType`. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`NoneType`](/mojo/stdlib/builtin/none/NoneType): Represents the absence of a value.
--- ## range
Implements a 'range' call. These are Mojo built-ins, so you don't need to import them. ## Functions * [​`range`](/mojo/stdlib/builtin/range/range): Constructs a \[0; end) Range.
--- ## range (Range)
`range[T: Indexer, //](end: T) -> _ZeroStartingRange` Constructs a \[0; end) Range. **Parameters:** * ​T ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the end value. **Args:** * ​end (`T`): The end of the range. **Returns:** `_ZeroStartingRange`: The constructed range. `range[T: IntableRaising, //](end: T) -> _ZeroStartingRange` Constructs a \[0; end) Range. **Parameters:** * ​T ([`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising)): The type of the end value. **Args:** * ​end (`T`): The end of the range. **Returns:** `_ZeroStartingRange`: The constructed range. **Raises:** An error if the conversion to an `Int` failed. `range(end: PythonObject) -> _ZeroStartingRange` Constructs a \[0; end) Range from a Python `int`. **Args:** * ​end ([`PythonObject`](/mojo/stdlib/python/python_object/PythonObject)): The end of the range as a Python `int`. **Returns:** `_ZeroStartingRange`: The constructed range. **Raises:** An error if converting `end` to an `Int` failed. `range[T0: Indexer, T1: Indexer, //](start: T0, end: T1) -> _SequentialRange` Constructs a \[start; end) Range. **Parameters:** * ​T0 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the start value. * ​T1 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the end value. **Args:** * ​start (`T0`): The start of the range. * ​end (`T1`): The end of the range. **Returns:** `_SequentialRange`: The constructed range. `range[T0: IntableRaising, T1: IntableRaising](start: T0, end: T1) -> _SequentialRange` Constructs a \[start; end) Range. **Parameters:** * ​T0 ([`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising)): The type of the start value. * ​T1 ([`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising)): The type of the end value. **Args:** * ​start (`T0`): The start of the range. * ​end (`T1`): The end of the range. **Returns:** `_SequentialRange`: The constructed range. **Raises:** An error if converting `start` or `end` to an `Int` failed. 
`range(start: PythonObject, end: PythonObject) -> _SequentialRange` Constructs a \[start; end) Range from Python `int` objects. **Args:** * ​start ([`PythonObject`](/mojo/stdlib/python/python_object/PythonObject)): The start of the range as a Python `int`. * ​end ([`PythonObject`](/mojo/stdlib/python/python_object/PythonObject)): The end of the range as a Python `int`. **Returns:** `_SequentialRange`: The constructed range. **Raises:** An error if converting `start` or `end` to an `Int` failed. `range[T0: Indexer, T1: Indexer, T2: Indexer, //](start: T0, end: T1, step: T2) -> _StridedRange` Constructs a \[start; end) Range with a given step. **Parameters:** * ​T0 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the start value. * ​T1 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the end value. * ​T2 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the step value. **Args:** * ​start (`T0`): The start of the range. * ​end (`T1`): The end of the range. * ​step (`T2`): The step for the range. **Returns:** `_StridedRange`: The constructed range. `range[T0: IntableRaising, T1: IntableRaising, T2: IntableRaising, //](start: T0, end: T1, step: T2) -> _StridedRange` Constructs a \[start; end) Range with a given step. **Parameters:** * ​T0 ([`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising)): The type of the start value. * ​T1 ([`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising)): The type of the end value. * ​T2 ([`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising)): The type of the step value. **Args:** * ​start (`T0`): The start of the range. * ​end (`T1`): The end of the range. * ​step (`T2`): The step for the range. **Returns:** `_StridedRange`: The constructed range. **Raises:** An error if converting `start`, `end`, or `step` to an `Int` failed. `range(start: PythonObject, end: PythonObject, step: PythonObject) -> _StridedRange` Constructs a \[start; end) Range from Python `int` objects with a given step. 
**Args:** * ​start ([`PythonObject`](/mojo/stdlib/python/python_object/PythonObject)): The start of the range as a Python `int`. * ​end ([`PythonObject`](/mojo/stdlib/python/python_object/PythonObject)): The end of the range as a Python `int`. * ​step ([`PythonObject`](/mojo/stdlib/python/python_object/PythonObject)): The step for the range as a Python `int`. **Returns:** `_StridedRange`: The constructed range. **Raises:** An error if converting `start`, `end`, or `step` to an `Int` failed. `range(end: UInt) -> _UIntZeroStartingRange` Constructs a \[0; end) Range. **Args:** * ​end ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The end of the range. **Returns:** `_UIntZeroStartingRange`: The constructed range. `range(start: UInt, end: UInt, step: UInt = 1) -> _UIntStridedRange` Constructs a \[start; end) Range with a given step. **Args:** * ​start ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The start of the range. * ​end ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The end of the range. * ​step ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The step for the range. Defaults to 1. **Returns:** `_UIntStridedRange`: The constructed range. `range[dtype: DType, //](end: Scalar[dtype]) -> _ZeroStartingScalarRange[dtype]` Constructs a \[0; end) Range. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The range dtype. **Args:** * ​end ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The end of the range. **Returns:** `_ZeroStartingScalarRange`: The constructed range. `range[dtype: DType, //](start: Scalar[dtype], end: Scalar[dtype]) -> _SequentialScalarRange[dtype]` Constructs a \[start; end) Range. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The range dtype. **Args:** * ​start ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The start of the range. * ​end ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The end of the range. **Returns:** `_SequentialScalarRange`: The constructed range. 
`range[dtype: DType, //](start: Scalar[dtype], end: Scalar[dtype], step: Scalar[dtype]) -> _StridedScalarRange[dtype]` Constructs a \[start; end) Range with a given step. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The range dtype. **Args:** * ​start ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The start of the range. * ​end ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The end of the range. * ​step ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The step for the range. **Returns:** `_StridedScalarRange`: The constructed range.
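A short sketch of a few of the overloads above (assuming the scalar overloads are iterable in a `for` loop, as the integer ones are):

```mojo
fn main():
    # Strided integer range: prints 0, 2, 4, 6, 8.
    for i in range(0, 10, 2):
        print(i)

    # UInt range with the default step of 1: prints 5, 6, 7.
    for u in range(UInt(5), UInt(8)):
        print(u)

    # Scalar range over Float32 values: 0.0, 1.0, 2.0.
    for x in range(Float32(0), Float32(3)):
        print(x)
```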
--- ## rebind
Implements type rebind/trait downcast. These are Mojo built-ins, so you don't need to import them. ## `comptime` values ### `downcast` `comptime downcast[_Trait: AnyTrait[UnknownDestructibility], T: UnknownDestructibility] = T(_Trait)` Type alias for downcasting a type to conform to a trait. #### Parameters * ​\_Trait (`AnyTrait`): The trait type to downcast to. * ​T ([`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility)): The type to downcast. ## Functions * [​`rebind`](/mojo/stdlib/builtin/rebind/rebind): Statically assert that a parameter input type `src_type` resolves to the same type as a parameter result type `dest_type` after function instantiation and "rebind" the input to the result type. * [​`rebind_var`](/mojo/stdlib/builtin/rebind/rebind_var): Statically assert that a parameter input type `src_type` resolves to the same type as a parameter result type `dest_type` after function instantiation and "rebind" the input to the result type, returning an owned variable with an adjusted type. * [​`trait_downcast`](/mojo/stdlib/builtin/rebind/trait_downcast): Downcast a parameter input type `T` and rebind the type such that the return value's type conforms to the provided `Trait`. If `T`, after resolving to a concrete type, does not actually conform to `Trait`, a compilation error occurs.
--- ## rebind (Rebind)
`rebind[src_type: AnyTrivialRegType, //, dest_type: AnyTrivialRegType](src: src_type) -> dest_type` Statically assert that a parameter input type `src_type` resolves to the same type as a parameter result type `dest_type` after function instantiation and "rebind" the input to the result type. This function is meant to be used in uncommon cases where a parametric type depends on the value of a constrained parameter in order to manually refine the type with the constrained parameter value. **Parameters:** * ​src\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The original type. * ​dest\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type to rebind to. **Args:** * ​src (`src_type`): The value to rebind. **Returns:** `dest_type`: The rebound value of `dest_type`. `rebind[src_type: UnknownDestructibility, //, dest_type: UnknownDestructibility](ref src: src_type) -> ref [src] dest_type` Statically assert that a parameter input type `src_type` resolves to the same type as a parameter result type `dest_type` after function instantiation and "rebind" the input to the result type, returning a reference to the input value with an adjusted type. This function is meant to be used in uncommon cases where a parametric type depends on the value of a constrained parameter in order to manually refine the type with the constrained parameter value. **Parameters:** * ​src\_type ([`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility)): The original type. * ​dest\_type ([`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility)): The type to rebind to. **Args:** * ​src (`src_type`): The value to rebind. **Returns:** `ref`: A reference to the value rebound as `dest_type`.
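A minimal sketch of the typical pattern (hypothetical helper names): a `constrained` check guarantees a parameter value, but the compiler still sees the original parametric type, so `rebind` is used to refine it:

```mojo
fn take_four(v: SIMD[DType.float32, 4]):
    print(v)

fn forward[size: Int](v: SIMD[DType.float32, size]):
    constrained[size == 4, "size must be 4"]()
    # After the constraint, size == 4 holds, but `v` still has type
    # SIMD[float32, size]; rebind refines it to SIMD[float32, 4].
    take_four(rebind[SIMD[DType.float32, 4]](v))

fn main():
    forward(SIMD[DType.float32, 4](1, 2, 3, 4))
```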
--- ## rebind_var
`rebind_var[src_type: Movable, //, dest_type: Movable](var src: src_type, out dest: dest_type)` Statically assert that a parameter input type `src_type` resolves to the same type as a parameter result type `dest_type` after function instantiation and "rebind" the input to the result type, returning an owned variable with an adjusted type. Unlike `rebind`, this function takes an owned variable and returns an owned variable by moving the value from the input to the output. This function is meant to be used in uncommon cases where a parametric type depends on the value of a constrained parameter in order to manually refine the type with the constrained parameter value. **Parameters:** * ​src\_type ([`Movable`](/mojo/stdlib/builtin/value/Movable)): The original type. * ​dest\_type ([`Movable`](/mojo/stdlib/builtin/value/Movable)): The type to rebind to. **Args:** * ​src (`src_type`): The value to rebind. **Returns:** `dest_type`: An owned value rebound as `dest_type`.
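The same pattern as `rebind`, sketched for an owned value (hypothetical helper names; assumes `forward` is only instantiated with `T = List[Int]`):

```mojo
fn consume(var xs: List[Int]):
    print(len(xs))

fn forward[T: Movable](var value: T):
    # The value is moved into the refined type rather than copied, so this
    # also works for move-only types. Instantiating with a T other than
    # List[Int] is a compile-time error.
    consume(rebind_var[List[Int]](value^))

fn main():
    forward(List[Int](1, 2, 3))
```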
--- ## trait_downcast
`trait_downcast[T: AnyTrivialRegType, //, Trait: AnyTrait[AnyType]](var src: T) -> T(Trait)` Downcast a parameter input type `T` and rebind the type such that the return value's type conforms to the provided `Trait`. If `T`, after resolving to a concrete type, does not actually conform to `Trait`, a compilation error occurs. **Parameters:** * ​T ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The original type. * ​Trait (`AnyTrait`): The trait to downcast into. **Args:** * ​src (`T`): The value to downcast. **Returns:** `T(Trait)`: The downcasted value. `trait_downcast[T: UnknownDestructibility, //, Trait: AnyTrait[UnknownDestructibility]](ref src: T) -> ref [src] T(Trait)` Downcast a parameter input type `T` and rebind the type such that the return value's type conforms to the provided `Trait`. If `T`, after resolving to a concrete type, does not actually conform to `Trait`, a compilation error occurs. **Parameters:** * ​T ([`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility)): The original type. * ​Trait (`AnyTrait`): The trait to downcast into. **Args:** * ​src (`T`): The value to downcast. **Returns:** `ref`: The downcasted value.
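A hedged sketch of the intended use (the exact call-site syntax may vary across Mojo versions): downcast a generic value to a trait it is known to conform to, so the trait's methods become usable:

```mojo
fn show[T: UnknownDestructibility](ref value: T):
    # Compiles only for instantiations where T actually conforms to
    # Writable; any other instantiation is a compile-time error.
    print(trait_downcast[Writable](value))

fn main():
    show(Int(42))
    show(String("hello"))
```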
--- ## Representable
A trait that describes a type that has a String representation. Any type that conforms to the `Representable` trait can be used with the `repr` function. Any conforming type must also implement the `__repr__` method. Here is an example: ```mojo struct Dog(Representable): var name: String var age: Int fn __repr__(self) -> String: return String( "Dog(name=", repr(self.name), ", age=", repr(self.age), ")" ) var dog = Dog("Rex", 5) print(repr(dog)) # Dog(name='Rex', age=5) ``` The method `__repr__` should compute the "official" string representation of a type. If at all possible, this should look like a valid Mojo expression that could be used to recreate a struct instance with the same value (given an appropriate environment). So a returned String of the form `module_name.SomeStruct(arg1=value1, arg2=value2)` is advised. If this is not possible, a string of the form `<...some useful description...>` should be returned. The return value must be a `String` instance. This is typically used for debugging, so it is important that the representation is information-rich and unambiguous. Note that when computing the string representation of a collection (`Dict`, `List`, `Set`, etc...), the `repr` function is called on each element, not the `String()` function. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. 
## Required methods ### `__repr__` `__repr__(self: _Self) -> String` Get the string representation of the type instance, if possible in a form compatible with Mojo syntax. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the instance.
--- ## repr
Provide the `repr` function. The functions and traits provided here are built-ins, so you don't need to import them. ## Traits * [​`Representable`](/mojo/stdlib/builtin/repr/Representable): A trait that describes a type that has a String representation. ## Functions * [​`repr`](/mojo/stdlib/builtin/repr/repr): Returns the string representation of the given value.
--- ## repr (Repr)
`repr[T: Representable](value: T) -> String` Returns the string representation of the given value. **Parameters:** * ​T ([`Representable`](/mojo/stdlib/builtin/repr/Representable)): The type of `value`. Must implement the `Representable` trait. **Args:** * ​value (`T`): The value to get the string representation of. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the given value. `repr[U: Copyable & Hashable & Equatable & Representable](value: Set[U]) -> String` Returns the string representation of a `Set[U]`. **Parameters:** * ​U ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable) & [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Representable`](/mojo/stdlib/builtin/repr/Representable)): A type that implements `KeyElement` and `Representable`. **Args:** * ​value ([`Set`](/mojo/stdlib/collections/set/Set)): A `Set` of elements `U`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of `Set[U]`. `repr[U: Copyable & Writable](value: LinkedList[U]) -> String` Returns the string representation of a `LinkedList[U]`. **Parameters:** * ​U ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): A type that implements `Copyable` and `Writable`. **Args:** * ​value ([`LinkedList`](/mojo/stdlib/collections/linked_list/LinkedList)): A `LinkedList` of element type `U`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of `LinkedList[U]`. `repr[T: Representable & Copyable](value: Deque[T]) -> String` Returns the string representation of a `Deque[T]`. **Parameters:** * ​T ([`Representable`](/mojo/stdlib/builtin/repr/Representable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): A type that implements `Copyable` and `Representable`. 
**Args:** * ​value ([`Deque`](/mojo/stdlib/collections/deque/Deque)): A `Deque` of element type `T`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of `Deque[T]`. `repr(value: None) -> String` Returns the string representation of `None`. **Args:** * ​value (`None`): A `None` value. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of `None`.
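A short sketch of `repr` on scalars and a collection (`Set` lives in the `collections` package; element order in the printed set is not guaranteed):

```mojo
from collections import Set

fn main():
    # Strings are quoted in their repr form, unlike String(...).
    print(repr(String("hello")))
    print(repr(None))

    # repr is called on each element of the collection.
    var values = Set[Int](1, 2, 3)
    print(repr(values))
```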
--- ## ReversibleRange
The `ReversibleRange` trait describes a range that can be reversed. Any type that conforms to `ReversibleRange` works with the builtin [`reversed()`](/mojo/stdlib/builtin/reversed.html) function. The `ReversibleRange` trait requires the type to define the `__reversed__()` method. **Note**: iterators are currently non-raising. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__reversed__` `__reversed__(self: _Self) -> _StridedRange` Get a reversed iterator for the type. **Note**: iterators are currently non-raising. **Returns:** `_StridedRange`: The reversed iterator of the type.
--- ## reversed
Provides the `reversed` function for reverse iteration over collections. These are Mojo built-ins, so you don't need to import them. ## Traits * [​`ReversibleRange`](/mojo/stdlib/builtin/reversed/ReversibleRange): The `ReversibleRange` trait describes a range that can be reversed. ## Functions * [​`reversed`](/mojo/stdlib/builtin/reversed/reversed): Get a reversed iterator of the input range.
--- ## reversed (Reversed)
`reversed[T: ReversibleRange](value: T) -> _StridedRange` Get a reversed iterator of the input range. **Note**: iterators are currently non-raising. **Parameters:** * ​T ([`ReversibleRange`](/mojo/stdlib/builtin/reversed/ReversibleRange)): The type conforming to ReversibleRange. **Args:** * ​value (`T`): The range to get the reversed iterator of. **Returns:** `_StridedRange`: The reversed iterator of the range. `reversed[T: Copyable](ref value: List[T]) -> _ListIter[T, value_is_origin, False]` Get a reversed iterator of the input list. **Note**: iterators are currently non-raising. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the list. **Args:** * ​value ([`List`](/mojo/stdlib/collections/list/List)): The list to get the reversed iterator of. **Returns:** `_ListIter`: The reversed iterator of the list. `reversed[T: Copyable](ref value: Deque[T]) -> _DequeIter[T, value_is_origin, False]` Get a reversed iterator of the deque. **Note**: iterators are currently non-raising. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. **Args:** * ​value ([`Deque`](/mojo/stdlib/collections/deque/Deque)): The deque to get the reversed iterator of. **Returns:** `_DequeIter`: The reversed iterator of the deque. `reversed[K: KeyElement, V: Copyable, H: Hasher](ref value: Dict[K, V, H]) -> _DictKeyIter[K, V, H, value_is_origin, False]` Get a reversed iterator of the input dict. **Note**: iterators are currently non-raising. **Parameters:** * ​K ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The type of the keys in the dict. * ​V ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the values in the dict. * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The type of the hasher in the dict. **Args:** * ​value ([`Dict`](/mojo/stdlib/collections/dict/Dict)): The dict to get the reversed iterator of. 
**Returns:** `_DictKeyIter`: The reversed iterator of the dict keys. `reversed[K: KeyElement, V: Copyable, H: Hasher, dict_mutability: Bool, dict_origin: Origin[dict_mutability]](ref value: _DictValueIter[K, V, H, dict_origin]) -> _DictValueIter[K, V, H, dict_origin, False]` Get a reversed iterator of the input dict values. **Note**: iterators are currently non-raising. **Parameters:** * ​K ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The type of the keys in the dict. * ​V ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the values in the dict. * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The type of the hasher in the dict. * ​dict\_mutability ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the reference to the dict values is mutable. * ​dict\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the dict values. **Args:** * ​value ([`_DictValueIter`](/mojo/stdlib/collections/dict/_DictValueIter)): The dict values to get the reversed iterator of. **Returns:** [`_DictValueIter`](/mojo/stdlib/collections/dict/_DictValueIter): The reversed iterator of the dict values. `reversed[K: KeyElement, V: Copyable, H: Hasher, dict_mutability: Bool, dict_origin: Origin[dict_mutability]](ref value: _DictEntryIter[K, V, H, dict_origin]) -> _DictEntryIter[K, V, H, dict_origin, False]` Get a reversed iterator of the input dict items. **Note**: iterators are currently non-raising. **Parameters:** * ​K ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The type of the keys in the dict. * ​V ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the values in the dict. * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The type of the hasher in the dict. * ​dict\_mutability ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the reference to the dict items is mutable. * ​dict\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the dict items. 
**Args:** * ​value ([`_DictEntryIter`](/mojo/stdlib/collections/dict/_DictEntryIter)): The dict items to get the reversed iterator of. **Returns:** [`_DictEntryIter`](/mojo/stdlib/collections/dict/_DictEntryIter): The reversed iterator of the dict items. `reversed[T: Copyable](value: Span[T, origin]) -> _SpanIter[T, origin, False]` Get a reversed iterator of the input Span. **Note**: iterators are currently non-raising. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the Span. **Args:** * ​value ([`Span`](/mojo/stdlib/memory/span/Span)): The Span to get the reversed iterator of. **Returns:** `_SpanIter`: The reversed iterator of the Span.
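A short sketch of reverse iteration over a builtin range and a `List`:

```mojo
fn main():
    # Reversing a builtin range: prints 2, 1, 0.
    for i in reversed(range(3)):
        print(i)

    # Reversing a List: prints 30, 20, 10.
    var items = List[Int](10, 20, 30)
    for item in reversed(items):
        print(item)
```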
--- ## FastMathFlag
`@register_passable` `struct FastMathFlag` Flags for controlling fast-math optimizations in floating-point operations. FastMathFlag provides compile-time controls for various floating-point math optimization modes that trade strict IEEE 754 compliance for performance. Available flags: * `NONE`: No fast-math optimizations. * `NNAN`: Assume operands and results are not NaN. * `NINF`: Assume operands and results are not +/- infinity. * `NSZ`: Treat the sign of a zero as insignificant. * `ARCP`: Allow reciprocal approximations. * `CONTRACT`: Allow floating-point contraction (e.g., fused multiply-add). * `AFN`: Allow algebraic function approximations. * `REASSOC`: Allow reassociation of floating-point operations. * `FAST`: Enable all fast-math optimizations. Examples: ```mojo # Use contract flag for fused multiply-add var result = value.fma[FastMathFlag.CONTRACT](multiplier, accumulator) # Use fast flag for maximum optimization var fast_result = value.fma[FastMathFlag.FAST](multiplier, accumulator) ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AFN` `comptime AFN = FastMathFlag(6)` Allow approximate function implementations. ### `ARCP` `comptime ARCP = FastMathFlag(4)` Allow reciprocal approximations. ### `CONTRACT` `comptime CONTRACT = FastMathFlag(5)` Allow floating-point contraction. ### `FAST` `comptime FAST = FastMathFlag(8)` Enable all fast-math optimizations. ### `NINF` `comptime NINF = FastMathFlag(2)` Assume no infinite values. 
### `NNAN` `comptime NNAN = FastMathFlag(1)` Assume no NaN values. ### `NONE` `comptime NONE = FastMathFlag(0)` No fast-math optimizations enabled. ### `NSZ` `comptime NSZ = FastMathFlag(3)` Treat the sign of zero as insignificant. ### `REASSOC` `comptime REASSOC = FastMathFlag(7)` Allow reassociation of operations. ## Methods ### `__is__` `__is__(self, other: Self) -> Bool` Compares two FastMathFlag values for identity. **Args:** * ​other (`Self`): The FastMathFlag to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both flags have the same value, False otherwise.
--- ## SIMD
`@register_passable(trivial)` `struct SIMD[dtype: DType, size: Int]` Represents a vector type that leverages hardware acceleration to process multiple data elements with a single operation. SIMD (Single Instruction, Multiple Data) is a fundamental parallel computing paradigm where a single CPU instruction operates on multiple data elements at once. Modern CPUs can perform 4, 8, 16, or even 32 operations in parallel using SIMD, delivering substantial performance improvements over scalar operations. Instead of processing one value at a time, SIMD processes entire vectors of values with each instruction. For example, when adding two vectors of four values, a scalar operation adds each value in the vector one by one, while a SIMD operation adds all four values at once using vector registers: ```text Scalar operation: SIMD operation: ┌─────────────────────────┐ ┌───────────────────────────┐ │ 4 instructions │ │ 1 instruction │ │ 4 clock cycles │ │ 1 clock cycle │ │ │ │ │ │ ADD a[0], b[0] → c[0] │ │ Vector register A │ │ ADD a[1], b[1] → c[1] │ │ ┌─────┬─────┬─────┬─────┐ │ │ ADD a[2], b[2] → c[2] │ │ │a[0] │a[1] │a[2] │a[3] │ │ │ ADD a[3], b[3] → c[3] │ │ └─────┴─────┴─────┴─────┘ │ └─────────────────────────┘ │ + │ │ Vector register B │ │ ┌─────┬─────┬─────┬─────┐ │ │ │b[0] │b[1] │b[2] │b[3] │ │ │ └─────┴─────┴─────┴─────┘ │ │ ↓ │ │ SIMD_ADD │ │ ↓ │ │ Vector register C │ │ ┌─────┬─────┬─────┬─────┐ │ │ │c[0] │c[1] │c[2] │c[3] │ │ │ └─────┴─────┴─────┴─────┘ │ └───────────────────────────┘ ``` The `SIMD` type maps directly to hardware vector registers and instructions. Mojo automatically generates optimal SIMD code that leverages CPU-specific instruction sets (such as AVX and NEON) without requiring manual intrinsics or assembly programming. This type is the foundation of high-performance CPU computing in Mojo, enabling you to write code that automatically leverages modern CPU vector capabilities while maintaining code clarity and portability. 
**Caution:** If you declare a SIMD vector size larger than the target hardware's vector registers, the compiler will break the SIMD vector up across multiple vector registers for compatibility. However, you should avoid using a vector that's more than 2x the hardware's vector register size because the resulting code will perform poorly. Key properties: * **Hardware-mapped**: Directly maps to CPU vector registers * **Type-safe**: Data types and vector sizes are checked at compile time * **Zero-cost**: No runtime overhead compared to hand-optimized intrinsics * **Portable**: Same code works across different CPU architectures (x86, ARM, etc.) * **Composable**: Seamlessly integrates with Mojo's parallelization features Key APIs: * Construction: * Broadcast single value to all elements: `SIMD[dtype, size](value)` * Initialize with specific values: `SIMD[dtype, size](v1, v2, ...)` * Zero-initialized vector: `SIMD[dtype, size]()` * Element operations: * Arithmetic: `+`, `-`, `*`, `/`, `%`, `//` * Comparison: `==`, `!=`, `<`, `<=`, `>`, `>=` * Math functions: `sqrt()`, `sin()`, `cos()`, `fma()`, etc. 
* Bit operations: `&`, `|`, `^`, `~`, `<<`, `>>` * Vector operations: * Horizontal reductions: `reduce_add()`, `reduce_mul()`, `reduce_min()`, `reduce_max()` * Element-wise conditional selection: `select(condition, true_case, false_case)` * Vector manipulation: `shuffle()`, `slice()`, `join()`, `split()` * Type conversion: `cast[target_dtype]()` Examples: Vectorized math operations: ```mojo # Process 8 floating-point numbers simultaneously var a = SIMD[DType.float32, 8](1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0) var b = SIMD[DType.float32, 8](2.0) # Broadcast 2.0 to all elements var result = a * b + 1.0 print(result) # => [3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0] ``` Conditional operations with masking: ```mojo # Double the positive values and negate the negative values var values = SIMD[DType.int32, 4](1, -2, 3, -4) var is_positive = values.gt(0) # greater-than: gets SIMD of booleans var result = is_positive.select(values * 2, values * -1) print(result) # => [2, 2, 6, 4] ``` Horizontal reductions: ```mojo # Sum all elements in a vector var data = SIMD[DType.float64, 4](10.5, 20.3, 30.1, 40.7) var total = data.reduce_add() var maximum = data.reduce_max() print(total, maximum) # => 101.6 40.7 ``` **Constraints:** The size of the SIMD vector must be positive and a power of 2. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of SIMD vector elements. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the SIMD vector (number of elements). 
## Implemented traits [`Absable`](/mojo/stdlib/builtin/math/Absable), [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`CeilDivable`](/mojo/stdlib/math/math/CeilDivable), [`Ceilable`](/mojo/stdlib/math/math/Ceilable), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`DivModable`](/mojo/stdlib/builtin/math/DivModable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Floorable`](/mojo/stdlib/math/math/Floorable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Indexer`](/mojo/stdlib/builtin/int/Indexer), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Powable`](/mojo/stdlib/builtin/math/Powable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Roundable`](/mojo/stdlib/builtin/math/Roundable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`Truncable`](/mojo/stdlib/math/math/Truncable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = SIMD[dtype, size]` SIMD types are remapped to the same type when passed to accelerator devices. ### `MAX` `comptime MAX = SIMD[dtype, size](max_or_inf[dtype]())` Gets the maximum value for the SIMD value, potentially +inf. ### `MAX_FINITE` `comptime MAX_FINITE = SIMD[dtype, size](max_finite[dtype]())` Returns the maximum finite value of SIMD value. 
### `MIN` `comptime MIN = SIMD[dtype, size](min_or_neg_inf[dtype]())` Gets the minimum value for the SIMD value, potentially -inf. ### `MIN_FINITE` `comptime MIN_FINITE = SIMD[dtype, size](min_finite[dtype]())` Returns the minimum (lowest) finite value of SIMD value. ## Methods ### `__init__` `__init__() -> Self` Default initializer of the SIMD vector. By default the SIMD vectors are initialized to all zeros. `__init__[other_dtype: DType, //](value: SIMD[other_dtype, size], /) -> Self` Initialize from another SIMD of the same size. If the value passed is a scalar, you can initialize a SIMD vector with more elements. Example: ```mojo print(UInt64(UInt8(42))) # 42 print(SIMD[DType.uint64, 4](UInt8(42))) # [42, 42, 42, 42] ``` Casting behavior: ```mojo # Basic casting preserves value within range Int8(UInt8(127)) == Int8(127) # Numbers above signed max wrap to negative using two's complement Int8(UInt8(128)) == Int8(-128) Int8(UInt8(129)) == Int8(-127) Int8(UInt8(256)) == Int8(0) # Negative signed cast to unsigned using two's complement UInt8(Int8(-128)) == UInt8(128) UInt8(Int8(-127)) == UInt8(129) UInt8(Int8(-1)) == UInt8(255) # Truncate precision after downcast and upcast Float64(Float32(Float64(123456789.123456789))) == Float64(123456792.0) # Rightmost bits of significand become 0's on upcast Float64(Float32(0.3)) == Float64(0.30000001192092896) # Numbers equal after truncation of float literal and cast truncation Float32(Float64(123456789.123456789)) == Float32(123456789.123456789) # Float to int/uint floors Int64(Float64(42.2)) == Int64(42) ``` **Parameters:** * ​other\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The type of the value that is being cast from. **Args:** * ​value ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to cast from. `@implicit` `__init__(value: UInt, /) -> Self` Initializes the SIMD vector with an unsigned integer. The unsigned integer value is splatted across all the elements of the SIMD vector. 
**Args:** * ​value ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The input value. `@implicit` `__init__(value: Int, /) -> Self` Initializes the SIMD vector with a signed integer. The signed integer value is splatted across all the elements of the SIMD vector. **Args:** * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): The input value. `__init__[T: Floatable, //](value: T, /) -> Float64` Initialize a Float64 from a type conforming to Floatable. **Parameters:** * ​T ([`Floatable`](/mojo/stdlib/builtin/floatable/Floatable)): The Floatable type. **Args:** * ​value (`T`): The object to get the float point representation of. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64) `__init__[T: FloatableRaising, //](out self: Float64, value: T, /)` Initialize a Float64 from a type conforming to FloatableRaising. **Parameters:** * ​T ([`FloatableRaising`](/mojo/stdlib/builtin/floatable/FloatableRaising)): The FloatableRaising type. **Args:** * ​value (`T`): The object to get the float point representation of. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64) **Raises:** If the type does not have a float point representation. `@implicit` `__init__(value: IntLiteral[value], /) -> Self` Initializes the SIMD vector with an integer. The integer value is splatted across all the elements of the SIMD vector. **Args:** * ​value ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The input value. `@implicit` `__init__(value: Bool, /) -> SIMD[DType.bool, size]` Initializes a Scalar with a bool value. Since this constructor does not splat, it can be implicit. **Args:** * ​value ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The bool value to initialize the Scalar with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) `__init__(*, fill: Bool) -> SIMD[DType.bool, size]` Initializes the SIMD vector with a bool value. The bool value is splatted across all elements of the SIMD vector. 
**Args:** * ​fill ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The bool value to fill each element of the SIMD vector with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD) `@implicit` `__init__(value: Scalar[dtype], /) -> Self` Constructs a SIMD vector by splatting a scalar value. The input value is splatted across all elements of the SIMD vector. **Args:** * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to splat to the elements of the vector. `__init__(*elems: Scalar[dtype], *, __list_literal__: Tuple[] = Tuple[]()) -> Self` Constructs a SIMD vector via a variadic list of elements. The input values are assigned to the corresponding elements of the SIMD vector. **Constraints:** The number of input values is equal to the size of the SIMD vector. **Args:** * ​\*elems ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The variadic list of elements from which the SIMD vector is constructed. * ​\_\_list\_literal\_\_ ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tells Mojo to use this method for list literals. `@implicit` `__init__(value: FloatLiteral[value], /) -> Self` Initializes the SIMD vector with a float. The value is splatted across all the elements of the SIMD vector. **Args:** * ​value ([`FloatLiteral`](/mojo/stdlib/builtin/float_literal/FloatLiteral)): The input value. `__init__[int_dtype: DType, //](*, from_bits: SIMD[int_dtype, size]) -> Self` Initializes the SIMD vector from the bits of an integral SIMD vector. **Parameters:** * ​int\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integral type of the input SIMD vector. **Args:** * ​from\_bits ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The SIMD vector to copy the bits from. ### `__bool__` `__bool__(self) -> Bool` Converts the SIMD scalar into a boolean value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the SIMD scalar is non-zero and False otherwise. ### `__getitem__` `__getitem__(self, idx: Int) -> Scalar[dtype]` Gets an element from the vector. 
**Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The element index. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The value at position `idx`. ### `__setitem__` `__setitem__(mut self, idx: Int, val: Scalar[dtype])` Sets an element in the vector. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index to set. * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to set. ### `__neg__` `__neg__(self) -> Self` Defines the unary `-` operation. **Returns:** `Self`: The negation of this SIMD vector. ### `__pos__` `__pos__(self) -> Self` Defines the unary `+` operation. **Returns:** `Self`: This SIMD vector. ### `__invert__` `__invert__(self) -> Self` Returns `~self`. **Constraints:** The element type of the SIMD vector must be boolean or integral. **Returns:** `Self`: The `~self` value. ### `__lt__` `__lt__(self, rhs: Self) -> Bool` Compares two Scalars using less-than comparison. **Args:** * ​rhs (`Self`): The Scalar to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is less than `rhs`, False otherwise. ### `__le__` `__le__(self, rhs: Self) -> Bool` Compares two Scalars using less-than-or-equal comparison. **Args:** * ​rhs (`Self`): The Scalar to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is less than or equal to `rhs`, False otherwise. ### `__eq__` `__eq__(self, rhs: Self) -> Bool` Compares two SIMD vectors for equality. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all elements of the SIMD vectors are equal, False otherwise. ### `__ne__` `__ne__(self, rhs: Self) -> Bool` Compares two SIMD vectors for inequality. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if any elements of the SIMD vectors are not equal, False otherwise. 
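As an illustrative sketch (not part of the generated reference), the indexing and elementwise arithmetic operators above compose as follows:

```mojo
var v = SIMD[DType.int32, 4](1, 2, 3, 4)
v[0] = 10                            # __setitem__ writes one lane
print(v[0])                          # __getitem__ reads one lane: 10
print(v + v)                         # elementwise __add__: [20, 4, 6, 8]
print(v // SIMD[DType.int32, 4](3))  # elementwise __floordiv__
```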
### `__gt__` `__gt__(self, rhs: Self) -> Bool` Compares two Scalars using greater-than comparison. **Args:** * ​rhs (`Self`): The Scalar to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is greater than `rhs`, False otherwise. ### `__ge__` `__ge__(self, rhs: Self) -> Bool` Compares two Scalars using greater-than-or-equal comparison. **Args:** * ​rhs (`Self`): The Scalar to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is greater than or equal to `rhs`, False otherwise. ### `__contains__` `__contains__(self, value: Scalar[dtype]) -> Bool` Whether the vector contains the value. **Args:** * ​value ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Whether the vector contains the value. ### `__add__` `__add__(self, rhs: Self) -> Self` Computes `self + rhs`. **Args:** * ​rhs (`Self`): The rhs value. **Returns:** `Self`: A new vector whose element at position `i` is computed as `self[i] + rhs[i]`. ### `__sub__` `__sub__(self, rhs: Self) -> Self` Computes `self - rhs`. **Args:** * ​rhs (`Self`): The rhs value. **Returns:** `Self`: A new vector whose element at position `i` is computed as `self[i] - rhs[i]`. ### `__mul__` `__mul__(self, rhs: Self) -> Self` Computes `self * rhs`. **Args:** * ​rhs (`Self`): The rhs value. **Returns:** `Self`: A new vector whose element at position `i` is computed as `self[i] * rhs[i]`. ### `__truediv__` `__truediv__(self, rhs: Self) -> Self` Computes `self / rhs`. **Args:** * ​rhs (`Self`): The rhs value. **Returns:** `Self`: A new vector whose element at position `i` is computed as `self[i] / rhs[i]`. ### `__floordiv__` `__floordiv__(self, rhs: Self) -> Self` Returns the division of self and rhs rounded down to the nearest integer. **Constraints:** The element type of the SIMD vector must be numeric. **Args:** * ​rhs (`Self`): The value to divide with. **Returns:** `Self`: `floor(self / rhs)` value. 
### `__mod__` `__mod__(self, rhs: Self) -> Self` Returns the remainder of self divided by rhs. **Args:** * ​rhs (`Self`): The value to divide with. **Returns:** `Self`: The remainder of dividing self by rhs. ### `__pow__` `__pow__(self, exp: Int) -> Self` Computes the vector raised to the power of the input integer value. **Args:** * ​exp ([`Int`](/mojo/stdlib/builtin/int/Int)): The exponent value. **Returns:** `Self`: A SIMD vector where each element is raised to the power of the specified exponent value. `__pow__(self, exp: Self) -> Self` Computes the vector raised elementwise to the right hand side power. **Args:** * ​exp (`Self`): The exponent value. **Returns:** `Self`: A SIMD vector where each element is raised to the power of the specified exponent value. ### `__lshift__` `__lshift__(self, rhs: Self) -> Self` Returns `self << rhs`. **Constraints:** The element type of the SIMD vector must be integral. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self << rhs`. ### `__rshift__` `__rshift__(self, rhs: Self) -> Self` Returns `self >> rhs`. **Constraints:** The element type of the SIMD vector must be integral. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self >> rhs`. ### `__and__` `__and__(self, rhs: Self) -> Self` Returns `self & rhs`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self & rhs`. ### `__or__` `__or__(self, rhs: Self) -> Self` Returns `self | rhs`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self | rhs`. ### `__xor__` `__xor__(self, rhs: Self) -> Self` Returns `self ^ rhs`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self ^ rhs`. ### `__radd__` `__radd__(self, value: Self) -> Self` Returns `value + self`. 
**Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value + self`. ### `__rsub__` `__rsub__(self, value: Self) -> Self` Returns `value - self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value - self`. ### `__rmul__` `__rmul__(self, value: Self) -> Self` Returns `value * self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value * self`. ### `__rtruediv__` `__rtruediv__(self, value: Self) -> Self` Returns `value / self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value / self`. ### `__rfloordiv__` `__rfloordiv__(self, rhs: Self) -> Self` Returns the division of rhs and self rounded down to the nearest integer. **Constraints:** The element type of the SIMD vector must be numeric. **Args:** * ​rhs (`Self`): The value to divide by self. **Returns:** `Self`: `floor(rhs / self)` value. ### `__rmod__` `__rmod__(self, value: Self) -> Self` Returns `value mod self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value mod self`. ### `__rpow__` `__rpow__(self, base: Self) -> Self` Returns `base ** self`. **Args:** * ​base (`Self`): The base value. **Returns:** `Self`: `base ** self`. ### `__rlshift__` `__rlshift__(self, value: Self) -> Self` Returns `value << self`. **Constraints:** The element type of the SIMD vector must be integral. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value << self`. ### `__rrshift__` `__rrshift__(self, value: Self) -> Self` Returns `value >> self`. **Constraints:** The element type of the SIMD vector must be integral. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value >> self`. ### `__rand__` `__rand__(self, value: Self) -> Self` Returns `value & self`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value & self`. ### `__ror__` `__ror__(self, value: Self) -> Self` Returns `value | self`. 
**Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value | self`. ### `__rxor__` `__rxor__(self, value: Self) -> Self` Returns `value ^ self`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value ^ self`. ### `__iadd__` `__iadd__(mut self, rhs: Self)` Performs in-place addition. The vector is mutated where each element at position `i` is computed as `self[i] + rhs[i]`. **Args:** * ​rhs (`Self`): The rhs of the addition operation. ### `__isub__` `__isub__(mut self, rhs: Self)` Performs in-place subtraction. The vector is mutated where each element at position `i` is computed as `self[i] - rhs[i]`. **Args:** * ​rhs (`Self`): The rhs of the operation. ### `__imul__` `__imul__(mut self, rhs: Self)` Performs in-place multiplication. The vector is mutated where each element at position `i` is computed as `self[i] * rhs[i]`. **Args:** * ​rhs (`Self`): The rhs of the operation. ### `__itruediv__` `__itruediv__(mut self, rhs: Self)` In-place true divide operator. The vector is mutated where each element at position `i` is computed as `self[i] / rhs[i]`. **Args:** * ​rhs (`Self`): The rhs of the operation. ### `__ifloordiv__` `__ifloordiv__(mut self, rhs: Self)` In-place floor division operator. The vector is mutated where each element at position `i` is computed as `self[i] // rhs[i]`. **Args:** * ​rhs (`Self`): The rhs of the operation. ### `__imod__` `__imod__(mut self, rhs: Self)` In-place mod operator. The vector is mutated where each element at position `i` is computed as `self[i] % rhs[i]`. **Args:** * ​rhs (`Self`): The rhs of the operation. ### `__ipow__` `__ipow__(mut self, rhs: Int)` In-place pow operator. The vector is mutated where each element at position `i` is computed as `pow(self[i], rhs)`. 
**Args:** * ​rhs ([`Int`](/mojo/stdlib/builtin/int/Int)): The rhs of the operation. ### `__ilshift__` `__ilshift__(mut self, rhs: Self)` Computes `self << rhs` and saves the result in `self`. **Constraints:** The element type of the SIMD vector must be integral. **Args:** * ​rhs (`Self`): The RHS value. ### `__irshift__` `__irshift__(mut self, rhs: Self)` Computes `self >> rhs` and saves the result in `self`. **Constraints:** The element type of the SIMD vector must be integral. **Args:** * ​rhs (`Self`): The RHS value. ### `__iand__` `__iand__(mut self, rhs: Self)` Computes `self & rhs` and saves the result in `self`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​rhs (`Self`): The RHS value. ### `__ixor__` `__ixor__(mut self, rhs: Self)` Computes `self ^ rhs` and saves the result in `self`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​rhs (`Self`): The RHS value. ### `__ior__` `__ior__(mut self, rhs: Self)` Computes `self | rhs` and saves the result in `self`. **Constraints:** The element type of the SIMD vector must be bool or integral. **Args:** * ​rhs (`Self`): The RHS value. ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when passing arguments to kernels. TODO: This will go away soon, when we get better error messages for kernel calls. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when passing arguments to kernels. TODO: This will go away soon, when we get better error messages for kernel calls. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This type's name. ### `__divmod__` `__divmod__(self, denominator: Self) -> Tuple[SIMD[dtype, size], SIMD[dtype, size]]` Computes both the quotient and remainder using floor division. 
**Args:** * ​denominator (`Self`): The value to divide by. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): The quotient and remainder as a `Tuple(self // denominator, self % denominator)`. ### `eq` `eq(self, rhs: Self) -> SIMD[DType.bool, size]` Compares two SIMD vectors using elementwise equality. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new bool SIMD vector of the same size whose element at position `i` is the value of `self[i] == rhs[i]`. ### `ne` `ne(self, rhs: Self) -> SIMD[DType.bool, size]` Compares two SIMD vectors using elementwise inequality. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new bool SIMD vector of the same size whose element at position `i` is the value of `self[i] != rhs[i]`. ### `gt` `gt(self, rhs: Self) -> SIMD[DType.bool, size]` Compares two SIMD vectors using elementwise greater-than comparison. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new bool SIMD vector of the same size whose element at position `i` is the value of `self[i] > rhs[i]`. ### `ge` `ge(self, rhs: Self) -> SIMD[DType.bool, size]` Compares two SIMD vectors using elementwise greater-than-or-equal comparison. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new bool SIMD vector of the same size whose element at position `i` is the value of `self[i] >= rhs[i]`. ### `lt` `lt(self, rhs: Self) -> SIMD[DType.bool, size]` Compares two SIMD vectors using elementwise less-than comparison. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new bool SIMD vector of the same size whose element at position `i` is the value of `self[i] < rhs[i]`. 
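A brief sketch (not from the generated reference) of how `__eq__`, which collapses to a single `Bool`, differs from the elementwise `eq`/`ne`/`gt` methods:

```mojo
var a = SIMD[DType.int32, 4](1, 5, 3, 7)
var b = SIMD[DType.int32, 4](1, 2, 3, 4)
print(a == b)   # __eq__: True only if every lane matches, so False here
print(a.eq(b))  # elementwise: [True, False, True, False]
print(a.gt(b))  # elementwise: [False, True, False, True]
```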
### `le` `le(self, rhs: Self) -> SIMD[DType.bool, size]` Compares two SIMD vectors using elementwise less-than-or-equal comparison. **Args:** * ​rhs (`Self`): The SIMD vector to compare with. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new bool SIMD vector of the same size whose element at position `i` is the value of `self[i] <= rhs[i]`. ### `__len__` `__len__(self) -> Int` Gets the length of the SIMD vector. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the SIMD vector. ### `__int__` `__int__(self) -> Int` Casts the value to an Int. If there is a fractional component, then the fractional part is truncated. **Constraints:** The size of the SIMD vector must be 1. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value as an integer. ### `__mlir_index__` `__mlir_index__(self) -> __mlir_type.index` Converts to index. **Returns:** `__mlir_type.index`: The corresponding \_\_mlir\_type.index value. ### `__float__` `__float__(self) -> Float64` Casts the value to a float. **Constraints:** The size of the SIMD vector must be 1. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): The value as a float. ### `__str__` `__str__(self) -> String` Gets the SIMD as a string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation. ### `__repr__` `__repr__(self) -> String` Gets the representation of the SIMD value, e.g. "SIMD\[DType.int8, 2]\(1, 2)". **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The representation of the SIMD value. ### `__floor__` `__floor__(self) -> Self` Performs elementwise floor on the elements of a SIMD vector. **Returns:** `Self`: The elementwise floor of this SIMD vector. ### `__ceil__` `__ceil__(self) -> Self` Performs elementwise ceiling on the elements of a SIMD vector. **Returns:** `Self`: The elementwise ceiling of this SIMD vector. ### `__trunc__` `__trunc__(self) -> Self` Performs elementwise truncation on the elements of a SIMD vector. 
**Returns:** `Self`: The elementwise truncated values of this SIMD vector. ### `__abs__` `__abs__(self) -> Self` Defines the absolute value operation. **Returns:** `Self`: The absolute value of this SIMD vector. ### `__round__` `__round__(self) -> Self` Performs elementwise rounding on the elements of a SIMD vector. This rounding goes to the nearest integer with ties away from zero. **Returns:** `Self`: The elementwise rounded value of this SIMD vector. `__round__(self, ndigits: Int) -> Self` Performs elementwise rounding on the elements of a SIMD vector. This rounding goes to the nearest integer with ties away from zero. **Args:** * ​ndigits ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of digits to round to. **Returns:** `Self`: The elementwise rounded value of this SIMD vector. ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with this SIMD value. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance. ### `__ceildiv__` `__ceildiv__(self, denominator: Self) -> Self` Return the rounded-up result of dividing self by denominator. **Args:** * ​denominator (`Self`): The denominator. **Returns:** `Self`: The ceiling of dividing numerator by denominator. ### `cast` `cast[target: DType](self) -> SIMD[target, size]` Casts the elements of the SIMD vector to the target element type. 
Casting behavior: ```mojo # Basic casting preserves value within range Int8(UInt8(127)) == Int8(127) # Numbers above signed max wrap to negative using two's complement Int8(UInt8(128)) == Int8(-128) Int8(UInt8(129)) == Int8(-127) Int8(UInt8(256)) == Int8(0) # Negative signed cast to unsigned using two's complement UInt8(Int8(-128)) == UInt8(128) UInt8(Int8(-127)) == UInt8(129) UInt8(Int8(-1)) == UInt8(255) # Truncate precision after downcast and upcast Float64(Float32(Float64(123456789.123456789))) == Float64(123456792.0) # Rightmost bits of significand become 0's on upcast Float64(Float32(0.3)) == Float64(0.30000001192092896) # Numbers equal after truncation of float literal and cast truncation Float32(Float64(123456789.123456789)) == Float32(123456789.123456789) # Float to int/uint floors Int64(Float64(42.2)) == Int64(42) ``` **Parameters:** * ​target ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The target DType. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new SIMD vector whose elements have been casted to the target element type. ### `is_power_of_two` `is_power_of_two(self) -> SIMD[DType.bool, size]` Checks if the input value is a power of 2 for each element of a SIMD vector. **Constraints:** The element type of the input vector must be integral. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A SIMD value where the element at position `i` is True if the integer at position `i` of the input value is a power of 2, False otherwise. ### `write_to` `write_to(self, mut writer: T)` Formats this SIMD value to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `to_bits` `to_bits[_dtype: DType = _uint_type_of_width[bit_width_of[dtype]()]()](self) -> SIMD[_dtype, size]` Bitcasts the SIMD vector to an integer SIMD vector. **Parameters:** * ​\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The integer type to cast to. 
**Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): An integer representation of the floating-point value. ### `from_bytes` `static from_bytes[*, big_endian: Bool = is_big_endian()](bytes: InlineArray[Byte, size_of[SIMD[dtype, size]]()]) -> Self` Converts a byte array to a vector. **Parameters:** * ​big\_endian ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the byte array is big-endian. **Args:** * ​bytes ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): The byte array to convert. **Returns:** `Self`: The converted vector. ### `as_bytes` `as_bytes[*, big_endian: Bool = is_big_endian()](self) -> InlineArray[Byte, size_of[SIMD[dtype, size]]()]` Converts the vector to a byte array. **Parameters:** * ​big\_endian ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the byte array should be big-endian. **Returns:** [`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray): The byte array. ### `clamp` `clamp(self, lower_bound: Self, upper_bound: Self) -> Self` Clamps the values in a SIMD vector to be in a certain range. Clamp cuts values in the input SIMD vector off at the upper bound and lower bound values. For example, SIMD vector `[0, 1, 2, 3]` clamped to a lower bound of 1 and an upper bound of 2 would return `[1, 1, 2, 2]`. **Args:** * ​lower\_bound (`Self`): Minimum of the range to clamp to. * ​upper\_bound (`Self`): Maximum of the range to clamp to. **Returns:** `Self`: A new SIMD vector containing `self` clamped to be within lower\_bound and upper\_bound. ### `fma` `fma[flag: FastMathFlag = FastMathFlag.CONTRACT](self, multiplier: Self, accumulator: Self) -> Self` Performs a fused multiply-add operation, i.e. `self*multiplier + accumulator`. **Parameters:** * ​flag ([`FastMathFlag`](/mojo/stdlib/builtin/simd/FastMathFlag)): Fast-math optimization flags to apply (default: CONTRACT). **Args:** * ​multiplier (`Self`): The value to multiply. * ​accumulator (`Self`): The value to accumulate. 
**Returns:** `Self`: A new vector whose element at position `i` is computed as `self[i]*multiplier[i] + accumulator[i]`. ### `shuffle` `shuffle[*mask: Int](self) -> Self` Shuffles (also called blend) the values of the current vector using the specified mask (permutation). The mask values must be within `2 * len(self)`. **Parameters:** * ​\*mask ([`Int`](/mojo/stdlib/builtin/int/Int)): The permutation to use in the shuffle. **Returns:** `Self`: A new vector with the same length as the mask where the value at position `i` is `(self)[permutation[i]]`. `shuffle[*mask: Int](self, other: Self) -> Self` Shuffles (also called blend) the values of the current vector with the `other` value using the specified mask (permutation). The mask values must be within `2 * len(self)`. **Parameters:** * ​\*mask ([`Int`](/mojo/stdlib/builtin/int/Int)): The permutation to use in the shuffle. **Args:** * ​other (`Self`): The other vector to shuffle with. **Returns:** `Self`: A new vector with the same length as the mask where the value at position `i` is `(self + other)[permutation[i]]`. `shuffle[mask: IndexList[size, element_type=element_type]](self) -> Self` Shuffles (also called blend) the values of the current vector using the specified mask (permutation). The mask values must be within `2 * len(self)`. **Parameters:** * ​mask ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The permutation to use in the shuffle. **Returns:** `Self`: A new vector with the same length as the mask where the value at position `i` is `(self)[permutation[i]]`. `shuffle[mask: IndexList[size, element_type=element_type]](self, other: Self) -> Self` Shuffles (also called blend) the values of the current vector with the `other` value using the specified mask (permutation). The mask values must be within `2 * len(self)`. **Parameters:** * ​mask ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The permutation to use in the shuffle. 
**Args:** * ​other (`Self`): The other vector to shuffle with. **Returns:** `Self`: A new vector with the same length as the mask where the value at position `i` is `(self + other)[permutation[i]]`. ### `slice` `slice[output_width: Int, /, *, offset: Int = 0](self) -> SIMD[dtype, output_width]` Returns a slice of the vector of the specified width with the given offset. **Constraints:** `output_width + offset` must not exceed the size of this SIMD vector. **Parameters:** * ​output\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The output SIMD vector size. * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The given offset for the slice. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new vector whose elements map to `self[offset:offset+output_width]`. ### `insert` `insert[*, offset: Int = 0](self, value: SIMD[dtype, size]) -> Self` Returns a new vector where the elements between `offset` and `offset + input_width` have been replaced with the elements in `value`. **Parameters:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset to insert at. This must be a multiple of value's size. **Args:** * ​value ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to be inserted. **Returns:** `Self`: A new vector whose elements at `self[offset:offset+input_width]` contain the values of `value`. ### `join` `join(self, other: Self) -> SIMD[dtype, (2 * size)]` Concatenates the two vectors together. **Args:** * ​other (`Self`): The other SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new vector `self_0, self_1, ..., self_n, other_0, ..., other_n`. ### `interleave` `interleave(self, other: Self) -> SIMD[dtype, (2 * size)]` Constructs a vector by interleaving two input vectors. **Args:** * ​other (`Self`): The other SIMD vector. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new vector `self_0, other_0, ..., self_n, other_n`. 
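A small sketch (not part of the generated reference) combining `slice`, `join`, and `interleave`:

```mojo
var v = SIMD[DType.uint8, 4](1, 2, 3, 4)
print(v.slice[2, offset=1]())  # [2, 3]
print(v.join(v))               # [1, 2, 3, 4, 1, 2, 3, 4]
print(v.interleave(v))         # [1, 1, 2, 2, 3, 3, 4, 4]
```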
### `split` `split(self) -> Tuple[SIMD[dtype, (size // 2)], SIMD[dtype, (size // 2)]]` Splits the SIMD vector into 2 subvectors. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Two vectors: the first of the form `self_0, ..., self_{N/2-1}` and the second of the form `self_{N/2}, ..., self_{N-1}`. ### `deinterleave` `deinterleave(self) -> Tuple[SIMD[dtype, (size // 2)], SIMD[dtype, (size // 2)]]` Constructs two vectors by deinterleaving the even and odd lanes of the vector. **Constraints:** The vector size must be greater than 1. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): Two vectors: the first of the form `self_0, self_2, ..., self_{n-2}` and the other being `self_1, self_3, ..., self_{n-1}`. ### `reduce` `reduce[func: fn[width: Int](SIMD[dtype, width], SIMD[dtype, width]) -> SIMD[dtype, width], size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using a provided reduce operator. **Constraints:** `size_out` must not exceed width of the vector. **Parameters:** * ​func (`fn[width: Int](SIMD[dtype, width], SIMD[dtype, width]) -> SIMD[dtype, width]`): The reduce function to apply to elements in this SIMD. * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new scalar which is the reduction of all vector elements. `reduce[func: fn[width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width], size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using a provided reduce operator. **Constraints:** `size_out` must not exceed width of the vector. **Parameters:** * ​func (`fn[width: Int](SIMD[dtype, width], SIMD[dtype, width]) capturing -> SIMD[dtype, width]`): The reduce function to apply to elements in this SIMD. * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new scalar which is the reduction of all vector elements. 
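To illustrate the generic `reduce` overloads, here is a sketch with a hypothetical helper `take_max` (any function matching the `fn[width: Int](...)` shape works; this is not from the generated reference):

```mojo
fn take_max[width: Int](a: SIMD[DType.int32, width], b: SIMD[DType.int32, width]) -> SIMD[DType.int32, width]:
    return max(a, b)  # elementwise max of the two halves

var v = SIMD[DType.int32, 8](3, 1, 4, 1, 5, 9, 2, 6)
print(v.reduce[take_max]())  # 9
```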
### `reduce_max` `reduce_max[size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using the `max` operator. **Constraints:** `size_out` must not exceed width of the vector. The element type of the vector must be integer or FP. **Parameters:** * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The maximum element of the vector. ### `reduce_min` `reduce_min[size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using the `min` operator. **Constraints:** `size_out` must not exceed width of the vector. The element type of the vector must be integer or FP. **Parameters:** * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The minimum element of the vector. ### `reduce_add` `reduce_add[size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using the `add` operator. **Constraints:** `size_out` must not exceed width of the vector. **Parameters:** * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The sum of all vector elements. ### `reduce_mul` `reduce_mul[size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using the `mul` operator. **Constraints:** `size_out` must not exceed width of the vector. The element type of the vector must be integer or FP. **Parameters:** * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The product of all vector elements. ### `reduce_and` `reduce_and[size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using the bitwise `&` operator. **Constraints:** `size_out` must not exceed width of the vector. The element type of the vector must be integer or boolean. 
**Parameters:** * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The reduced vector. ### `reduce_or` `reduce_or[size_out: Int = 1](self) -> SIMD[dtype, size_out]` Reduces the vector using the bitwise `|` operator. **Constraints:** `size_out` must not exceed width of the vector. The element type of the vector must be integer or boolean. **Parameters:** * ​size\_out ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the reduction. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The reduced vector. ### `reduce_bit_count` `reduce_bit_count(self) -> Int` Returns the total number of bits set in the SIMD vector. **Constraints:** Must be either an integral or a boolean type. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Count of set bits across all elements of the vector. ### `select` `select[_dtype: DType](self, true_case: SIMD[_dtype, size], false_case: SIMD[_dtype, size]) -> SIMD[_dtype, size]` Selects the values of the `true_case` or the `false_case` based on the current boolean values of the SIMD vector. **Constraints:** The element type of the vector must be boolean. **Parameters:** * ​\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The element type of the input and output SIMD vectors. **Args:** * ​true\_case ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The values selected if the positional value is True. * ​false\_case ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The values selected if the positional value is False. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): A new vector of the form `[true_case[i] if elem else false_case[i] for i, elem in enumerate(self)]`. ### `rotate_left` `rotate_left[shift: Int](self) -> Self` Shifts the elements of a SIMD vector to the left by `shift` elements (with wrap-around). 
**Constraints:** `-size <= shift < size` **Parameters:** * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of positions by which to rotate the elements of the SIMD vector to the left (with wrap-around). **Returns:** `Self`: The SIMD vector rotated to the left by `shift` elements (with wrap-around). ### `rotate_right` `rotate_right[shift: Int](self) -> Self` Shifts the elements of a SIMD vector to the right by `shift` elements (with wrap-around). **Constraints:** `-size < shift <= size` **Parameters:** * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of positions by which to rotate the elements of the SIMD vector to the right (with wrap-around). **Returns:** `Self`: The SIMD vector rotated to the right by `shift` elements (with wrap-around). ### `shift_left` `shift_left[shift: Int](self) -> Self` Shifts the elements of a SIMD vector to the left by `shift` elements (no wrap-around, fill with zero). **Constraints:** `0 <= shift <= size` **Parameters:** * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of positions by which to shift the elements of the SIMD vector to the left (no wrap-around, fill with zero). **Returns:** `Self`: The SIMD vector shifted to the left by `shift` elements (no wrap-around, fill with zero). ### `shift_right` `shift_right[shift: Int](self) -> Self` Shifts the elements of a SIMD vector to the right by `shift` elements (no wrap-around, fill with zero). **Constraints:** `0 <= shift <= size` **Parameters:** * ​shift ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of positions by which to shift the elements of the SIMD vector to the right (no wrap-around, fill with zero). **Returns:** `Self`: The SIMD vector shifted to the right by `shift` elements (no wrap-around, fill with zero). ### `reversed` `reversed(self) -> Self` Reverses the order of the elements in the SIMD vector. Examples:

```mojo
print(SIMD[DType.uint8, 4](1, 2, 3, 4).reversed())  # [4, 3, 2, 1]
```

**Returns:** `Self`: The vector with its elements in reverse order.
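The reduction and element-movement methods compose naturally. A minimal sketch exercising them together (assuming the `SIMD` API documented in this section; expected values follow the semantics described above):

```mojo
def main():
    var v = SIMD[DType.int32, 4](3, 1, 4, 1)

    # Horizontal reductions collapse the vector to a scalar.
    print(v.reduce_max())  # 4
    print(v.reduce_add())  # 9

    # rotate_* wraps elements around; shift_* fills vacated lanes with zero.
    print(v.rotate_left[1]())  # [1, 4, 1, 3]
    print(v.shift_left[1]())   # [1, 4, 1, 0]
```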
--- ## simd (Simd)
Implements SIMD primitives and abstractions. Provides high-performance SIMD primitives and abstractions for vectorized computation in Mojo. It enables efficient data-parallel operations by leveraging hardware vector processing units across different architectures. Key Features: 1. Architecture-agnostic SIMD abstractions with automatic hardware detection 2. Optimized vector operations for common numerical computations 3. Explicit control over vectorization strategies and memory layouts 4. Zero-cost abstractions that compile to efficient machine code 5. Support for different vector widths and element types Primary Components: * Vector types: Strongly-typed vector containers with element-wise operations * SIMD intrinsics: Low-level access to hardware SIMD instructions * Vectorized algorithms: Common algorithms optimized for SIMD execution * Memory utilities: Aligned memory allocation and vector load/store operations Performance Considerations: * Vector width selection should match target hardware capabilities * Memory alignment affects load/store performance * Data layout transformations may be necessary for optimal vectorization Integration: This module is designed to work seamlessly with other Mojo numerical computing components, including tensor operations, linear algebra routines, and domain-specific libraries for machine learning and scientific computing. ## `comptime` values ### `BFloat16` `comptime BFloat16 = BFloat16` Represents a 16-bit brain floating point value. ### `Byte` `comptime Byte = UInt8` Represents a byte (backed by an 8-bit unsigned integer). ### `Float16` `comptime Float16 = Float16` Represents a 16-bit floating point value. ### `Float32` `comptime Float32 = Float32` Represents a 32-bit floating point value. ### `Float4_e2m1fn` `comptime Float4_e2m1fn = Float4_e2m1fn` Represents a 4-bit `e2m1` floating point format. 
This type is encoded as `s.ee.m` and defined by the [Open Compute MX Format Specification](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf): * (s)ign: 1 bit * (e)xponent: 2 bits * (m)antissa: 1 bit * exponent bias: 1 ### `Float64` `comptime Float64 = Float64` Represents a 64-bit floating point value. ### `Float8_e4m3fn` `comptime Float8_e4m3fn = Float8_e4m3fn` Represents the E4M3 floating point format defined in the [OFP8 standard](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1). This type is named differently across libraries and vendors, for example: * Mojo, PyTorch, JAX, and LLVM refer to it as `e4m3fn`. * OCP, NVIDIA CUDA, and AMD ROCm refer to it as `e4m3`. In these contexts, they are all referring to the same finite type specified in the OFP8 standard above, encoded as `seeeemmm`: * (s)ign: 1 bit * (e)xponent: 4 bits * (m)antissa: 3 bits * exponent bias: 7 * nan: 01111111, 11111111 * -0: 10000000 * fn: finite (no inf or -inf encodings) ### `Float8_e4m3fnuz` `comptime Float8_e4m3fnuz = Float8_e4m3fnuz` Represents an 8-bit e4m3fnuz floating point format. This type is encoded as `seeeemmm`: * (s)ign: 1 bit * (e)xponent: 4 bits * (m)antissa: 3 bits * exponent bias: 8 * nan: 10000000 * fn: finite (no inf or -inf encodings) * uz: unsigned zero (no -0 encoding) ### `Float8_e5m2` `comptime Float8_e5m2 = Float8_e5m2` Represents the 8-bit E5M2 floating point format. This type is from the [OFP8 standard](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1), encoded as `seeeeemm`: * (s)ign: 1 bit * (e)xponent: 5 bits * (m)antissa: 2 bits * exponent bias: 15 * nan: {0,1}11111{01,10,11} * inf: 01111100 * -inf: 11111100 * -0: 10000000 ### `Float8_e5m2fnuz` `comptime Float8_e5m2fnuz = Float8_e5m2fnuz` Represents an 8-bit e5m2fnuz floating point format.
This type is encoded as `seeeeemm`: * (s)ign: 1 bit * (e)xponent: 5 bits * (m)antissa: 2 bits * exponent bias: 16 * nan: 10000000 * fn: finite (no inf or -inf encodings) * uz: unsigned zero (no -0 encoding) ### `Float8_e8m0fnu` `comptime Float8_e8m0fnu = Float8_e8m0fnu` Represents the 8-bit E8M0FNU floating point format. This type is defined in the [OCP MX spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf), encoded as `eeeeeeee`: * (e)xponent: 8 bits * (m)antissa: 0 bits * exponent bias: 127 * nan: 11111111 * fn: finite (no inf or -inf encodings) * u: unsigned (no sign bit or zero value) ### `Int128` `comptime Int128 = Int128` Represents a 128-bit signed scalar integer. ### `Int16` `comptime Int16 = Int16` Represents a 16-bit signed scalar integer. ### `Int256` `comptime Int256 = Int256` Represents a 256-bit signed scalar integer. ### `Int32` `comptime Int32 = Int32` Represents a 32-bit signed scalar integer. ### `Int64` `comptime Int64 = Int64` Represents a 64-bit signed scalar integer. ### `Int8` `comptime Int8 = Int8` Represents an 8-bit signed scalar integer. ### `Scalar` `comptime Scalar = Scalar[?]` Represents a scalar dtype. ### `U8x16` `comptime U8x16 = SIMD[DType.uint8, 16]` A 16-element vector of unsigned 8-bit integers. ### `UInt128` `comptime UInt128 = UInt128` Represents a 128-bit unsigned scalar integer. ### `UInt16` `comptime UInt16 = UInt16` Represents a 16-bit unsigned scalar integer. ### `UInt256` `comptime UInt256 = UInt256` Represents a 256-bit unsigned scalar integer. ### `UInt32` `comptime UInt32 = UInt32` Represents a 32-bit unsigned scalar integer. ### `UInt64` `comptime UInt64 = UInt64` Represents a 64-bit unsigned scalar integer. ### `UInt8` `comptime UInt8 = UInt8` Represents an 8-bit unsigned scalar integer. ## Structs * [​`FastMathFlag`](/mojo/stdlib/builtin/simd/FastMathFlag): Flags for controlling fast-math optimizations in floating-point operations. 
* [​`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Represents a vector type that leverages hardware acceleration to process multiple data elements with a single operation.
--- ## sort
Implements the built-in `sort` function. These are Mojo built-ins, so you don't need to import them. ## `comptime` values ### `insertion_sort_threshold` `comptime insertion_sort_threshold = 32` Threshold below which insertion sort is used instead of quicksort. ## Functions * [​`partition`](/mojo/stdlib/builtin/sort/partition): Partition the input buffer in-place so that the first k elements are the largest (or smallest, if `cmp_fn` is the `<` operator). The ordering of the first k elements is undefined. * [​`sort`](/mojo/stdlib/builtin/sort/sort): Sort a span in-place. The function doesn't return anything; the span is updated in-place.
--- ## partition
`partition[T: Copyable, origin: MutOrigin, //, cmp_fn: fn(T, T) capturing -> Bool](span: Span[T, origin], k: Int)` Partition the input buffer in-place so that the first k elements are the largest (or smallest, if `cmp_fn` is the `<` operator). The ordering of the first k elements is undefined. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): Type of the underlying data. * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): Origin of span. * ​cmp\_fn (`fn(T, T) capturing -> Bool`): The comparison function, of type `fn(T, T) capturing -> Bool`. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): Input buffer. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Index of the partition element.
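As an illustration of how `partition` might be called (a sketch; it assumes a `List` coerces to a mutable `Span`, and the `greater` comparator is a hypothetical name introduced here):

```mojo
def main():
    var data = List[Int](5, 1, 9, 3, 7)

    @parameter
    fn greater(a: Int, b: Int) -> Bool:
        return a > b

    # Move the two largest values (9 and 7) to the front; their
    # relative order after the call is undefined.
    partition[greater](data, 2)
    print(data[0] + data[1])  # 16 in either ordering
```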
--- ## sort (Sort)
`sort[T: Copyable, origin: MutOrigin, //, cmp_fn: fn(T, T) capturing -> Bool, *, stable: Bool = False, __disambiguate: NoneType = None](span: Span[T, origin])` Sort a span in-place. The function doesn't return anything; the span is updated in-place. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): Copyable type of the underlying data. * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): Origin of span. * ​cmp\_fn (`fn(T, T) capturing -> Bool`): The comparison function. * ​stable ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the sort should be stable. * ​\_\_disambiguate ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): Give the Scalar overload higher priority. Do not pass explicitly. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The span to be sorted. `sort[dtype: DType, origin: MutOrigin, //, cmp_fn: fn(Scalar[dtype], Scalar[dtype]) capturing -> Bool, *, stable: Bool = False](span: Span[Scalar[dtype], origin])` Sort a span of Scalar elements in-place. The function doesn't return anything; the span is updated in-place. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of elements. * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): Origin of span. * ​cmp\_fn (`fn(Scalar[dtype], Scalar[dtype]) capturing -> Bool`): The comparison function. * ​stable ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the sort should be stable. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The span to be sorted. `sort[origin: MutOrigin, //, cmp_fn: fn(Int, Int) capturing -> Bool, *, stable: Bool = False](span: Span[Int, origin])` Sort a span in-place. The function doesn't return anything; the span is updated in-place. **Parameters:** * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): Origin of span. * ​cmp\_fn (`fn(Int, Int) capturing -> Bool`): The comparison function.
* ​stable ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the sort should be stable. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The span to be sorted. `sort[origin: MutOrigin, //, *, stable: Bool = False](span: Span[Int, origin])` Sort a span in-place. The function doesn't return anything; the span is updated in-place. **Parameters:** * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): Origin of span. * ​stable ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the sort should be stable. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The span to be sorted. `sort[T: Copyable & Comparable, origin: MutOrigin, //, *, stable: Bool = False](span: Span[T, origin])` Sort a span of comparable elements in-place. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable)): The order-comparable collection element type. * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): Origin of span. * ​stable ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the sort should be stable. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The span to be sorted.
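The overloads above can be exercised as follows; a minimal sketch (assuming a `List` coerces to a mutable `Span`, and using a hypothetical `descending` comparator):

```mojo
def main():
    var nums = List[Int](4, 2, 8, 1)

    # Default ascending sort (Int overload).
    sort(nums)
    print(nums[0], nums[3])  # 1 8

    # Custom comparator: descending order.
    @parameter
    fn descending(a: Int, b: Int) -> Bool:
        return a > b

    sort[descending](nums)
    print(nums[0], nums[3])  # 8 1
```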
--- ## Stringable
The `Stringable` trait describes a type that can be converted to a [`String`](/mojo/stdlib/collections/string/String). Any type that conforms to `Stringable` or [`StringableRaising`](/mojo/stdlib/builtin/str/StringableRaising) works with the built-in [`print()`](/mojo/stdlib/builtin/io/print) and [`String()`](/mojo/stdlib/builtin/str/str) functions. The `Stringable` trait requires the type to define the `__str__()` method. For example:

```mojo
struct Foo(Stringable):
    var s: String

    fn __str__(self) -> String:
        return self.s
```

Now you can pass an instance of `Foo` to the `String()` function to get back a `String`:

```mojo
var foo = Foo("test")
print(String(foo) == "test")
```

```plaintext
True
```

**Note:** If the `__str__()` method might raise an error, use the [`StringableRaising`](/mojo/stdlib/builtin/str/StringableRaising) trait instead. About the difference between `__repr__()` and `__str__()`: The method `__repr__` computes the "official" string representation of an object while `__str__` computes the "informal" or nicely printable string representation of an object. This method differs from `__repr__()` in that there is no expectation that `__str__()` return a valid Mojo expression: a more convenient or concise representation can be used. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__str__` `__str__(self: _Self) -> String` Get the string representation of the type.
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the type.
--- ## StringableRaising
The StringableRaising trait describes a type that can be converted to a [`String`](/mojo/stdlib/collections/string/String). Any type that conforms to [`Stringable`](/mojo/stdlib/builtin/str/Stringable) or `StringableRaising` works with the built-in [`print()`](/mojo/stdlib/builtin/io/print) and [`String()`](/mojo/stdlib/builtin/str/str) functions. The `StringableRaising` trait requires the type to define the `__str__()` method, which can raise an error. For example:

```mojo
struct Foo(StringableRaising):
    var s: String

    fn __str__(self) raises -> String:
        if self.s == "":
            raise Error("Empty String")
        return self.s
```

Now you can pass an instance of `Foo` to the `String()` function to get back a `String`:

```mojo
def main():
    var foo = Foo("test")
    print(String(foo) == "test")
```

```plaintext
True
```

## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__str__` `__str__(self: _Self) -> String` Get the string representation of the type. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the type. **Raises:** If there is an error when computing the string representation of the type.
--- ## str
Provides the `Stringable` and `StringableRaising` traits. These are Mojo built-ins, so you don't need to import them. ## Traits * [​`Stringable`](/mojo/stdlib/builtin/str/Stringable): The `Stringable` trait describes a type that can be converted to a [`String`](/mojo/stdlib/collections/string/String). * [​`StringableRaising`](/mojo/stdlib/builtin/str/StringableRaising): The StringableRaising trait describes a type that can be converted to a [`String`](/mojo/stdlib/collections/string/String).
--- ## StringLiteral
`@register_passable(trivial)` ``struct StringLiteral[value: __mlir_type.`!kgen.string`]`` This type represents a string literal. String literals are all null-terminated for compatibility with C APIs, but this is subject to change. String literals store their length as an integer, and this does not include the null terminator. ## Parameters * ​value (``__mlir_type.`!kgen.string` ``): The underlying string value. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`ConvertibleToPython`](/mojo/stdlib/python/conversions/ConvertibleToPython), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`FloatableRaising`](/mojo/stdlib/builtin/floatable/FloatableRaising), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`PathLike`](/mojo/stdlib/os/pathlike/PathLike), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__() -> Self` Constructor for any value. ### `__bool__` `__bool__(self) -> Bool` Convert the string to a bool value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the string is not empty. ### `__getitem__` `__getitem__[I: Indexer, //](self, idx: I) -> StaticString` Gets the character at the specified position. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The inferred type of an indexer argument.
**Args:** * ​idx (`I`): The index value. **Returns:** `StaticString`: A StringSlice view containing the character at the specified position. ### `__lt__` `__lt__(self, rhs: StringSlice[origin]) -> Bool` Compare this value to the RHS using lesser than (LT) comparison. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The other value to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this is strictly less than the RHS and False otherwise. ### `__le__` `__le__(self, rhs: StringSlice[origin]) -> Bool` Compare this value to the RHS using lesser than or equal to (LE) comparison. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The other value to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this is less than or equal to the RHS and False otherwise. ### `__eq__` `__eq__(self, rhs: StringSlice[origin]) -> Bool` Compare two string literals for equality. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if they are equal. ### `__ne__` `__ne__(self, rhs: StringSlice[origin]) -> Bool` Compare two string literals for inequality. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to compare. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if they are not equal. ### `__gt__` `__gt__(self, rhs: StringSlice[origin]) -> Bool` Compare this value to the RHS using greater than (GT) comparison. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The other value to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this is strictly greater than the RHS and False otherwise. 
### `__ge__` `__ge__(self, rhs: StringSlice[origin]) -> Bool` Compare this value to the RHS using greater than or equal to (GE) comparison. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The other value to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this is greater than or equal to the RHS and False otherwise. ### `__add__` `__add__(self, rhs: StringLiteral[value]) -> StringLiteral[#pop.string_concat]` Concatenate two string literals. **Args:** * ​rhs ([`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral)): The string to concatenate. **Returns:** [`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral): The concatenated string. ### `__mul__` `__mul__(self, n: Int) -> String` Concatenates the string `n` times. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of times to concatenate the string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string concatenated `n` times. ### `to_python_object` `to_python_object(var self) -> PythonObject` Convert this value to a PythonObject. **Returns:** [`PythonObject`](/mojo/stdlib/python/python_object/PythonObject): A PythonObject representing the value. **Raises:** If the Python runtime is not initialized or conversion fails. ### `__len__` `__len__(self) -> Int` Get the string length. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of this value. ### `__int__` `__int__(self) -> Int` Parses the given string as a base-10 integer and returns that value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer value that represents the string. **Raises:** If the string cannot be parsed as a valid base-10 integer. ### `__float__` `__float__(self) -> Float64` Parses the string as a floating-point number and returns that value. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): A float value that represents the string.
**Raises:** If the string cannot be parsed as a valid floating-point number. ### `__str__` `__str__(self) -> String` Convert the string literal to a string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new string. ### `__repr__` `__repr__(self) -> String` Return a representation of this value. You don't need to call this method directly, use `repr("...")` instead. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new representation of the string. ### `__fspath__` `__fspath__(self) -> String` Return the file system path representation of the object. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The file system path representation as a string. ### `__iter__` `__iter__(self) -> CodepointSliceIter[StaticConstantOrigin]` Return an iterator over the string literal. **Returns:** `CodepointSliceIter`: An iterator over the string. ### `__reversed__` `__reversed__(self) -> CodepointSliceIter[StaticConstantOrigin, False]` Iterate backwards over the string, returning immutable references. **Returns:** `CodepointSliceIter`: A reversed iterator over the string. ### `__merge_with__` `__merge_with__[other_type: AnyStruct[StringLiteral[value]]](self) -> StaticString` Returns a StaticString after merging with another string literal. **Parameters:** * ​other\_type (`AnyStruct`): The type of the string literal to merge with. **Returns:** `StaticString`: A StaticString after merging with the specified `other_type`. ### `byte_length` `byte_length(self) -> Int` Get the string length in bytes. Notes: This does not include the trailing null terminator in the count. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of this string in bytes. ### `unsafe_ptr` `unsafe_ptr(self) -> UnsafePointer[Byte, StaticConstantOrigin]` Get raw pointer to the underlying data. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The raw pointer to the data. 
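A short sketch exercising the conversion and length methods above (the `Int(...)` parsing constructor is an assumption based on the `IntableRaising` conformance listed for this type):

```mojo
def main():
    print(len("hello"))           # 5  (__len__)
    print("hello".byte_length())  # 5  (trailing null not counted)
    print(Int("42") + 1)          # 43 (__int__ parses base-10)
    print("ab" + "cd")            # abcd (__add__)
```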
### `unsafe_cstr_ptr` `unsafe_cstr_ptr(self) -> UnsafePointer[c_char, StaticConstantOrigin]` Retrieves a C-string-compatible pointer to the underlying memory. The returned pointer is guaranteed to be NUL terminated, and not null. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The pointer to the underlying memory. ### `as_c_string_slice` `as_c_string_slice(self) -> CStringSlice[StaticConstantOrigin]` Return a `CStringSlice` to the underlying memory of the string. **Returns:** `CStringSlice`: The `CStringSlice` of the string. ### `as_string_slice` `as_string_slice(self) -> StaticString` Returns a string slice of this static string literal. **Returns:** `StaticString`: A string slice pointing to this static string literal. ### `as_bytes` `as_bytes(self) -> Span[Byte, StaticConstantOrigin]` Returns a contiguous Span of the bytes owned by this string. **Returns:** [`Span`](/mojo/stdlib/memory/span/Span): A contiguous slice pointing to the bytes owned by this string. ### `write_to` `write_to(self, mut writer: T)` Formats this string literal to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `find` `find(self, substr: StringSlice[StaticConstantOrigin], start: Int = 0) -> Int` Finds the offset of the first occurrence of `substr` starting at `start`. If not found, returns -1. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to find. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset from which to find. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The offset of `substr` relative to the beginning of the string. ### `rfind` `rfind(self, substr: StringSlice[StaticConstantOrigin], start: Int = 0) -> Int` Finds the offset of the last occurrence of `substr` starting at `start`. If not found, returns -1. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to find. 
* ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset from which to find. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The offset of `substr` relative to the beginning of the string. ### `count` `count(self, substr: StringSlice[origin]) -> Int` Return the number of non-overlapping occurrences of substring `substr` in the string literal. If `substr` is empty, returns the number of empty strings between characters, which is the length of the string plus one. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to count. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of occurrences of `substr`. ### `lower` `lower(self) -> String` Returns a copy of the string literal with all cased characters converted to lowercase. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new string where cased letters have been converted to lowercase. ### `upper` `upper(self) -> String` Returns a copy of the string literal with all cased characters converted to uppercase. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new string where cased letters have been converted to uppercase. ### `rjust` `rjust(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> String` Returns the string literal right justified in a string of specified width. Pads the string literal on the left with the specified fill character so that the total length of the resulting string equals `width`. If the original string literal is already longer than or equal to `width`, returns the string literal unchanged (as a `String`). Examples:

```mojo
var s = "hello"
print(s.rjust(10))       # "     hello"
print(s.rjust(10, "*"))  # "*****hello"
print(s.rjust(3))        # "hello" (no padding)
```

**Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string.
* ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A right-justified string of length `width`, or the original string literal (as a `String`) if its length is already greater than or equal to `width`. ### `ljust` `ljust(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> String` Returns the string literal left justified in a string of specified width. Pads the string literal on the right with the specified fill character so that the total length of the resulting string equals `width`. If the original string literal is already longer than or equal to `width`, returns the string literal unchanged (as a `String`). Examples:

```mojo
var s = "hello"
print(s.ljust(10))       # "hello     "
print(s.ljust(10, "*"))  # "hello*****"
print(s.ljust(3))        # "hello" (no padding)
```

**Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A left-justified string of length `width`, or the original string literal (as a `String`) if its length is already greater than or equal to `width`. ### `center` `center(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> String` Returns the string literal center justified in a string of specified width. Pads the string literal on both sides with the specified fill character so that the total length of the resulting string equals `width`. If the padding needed is odd, the extra character goes on the right side.
If the original string literal is already longer than or equal to `width`, returns the string literal unchanged (as a `String`). Examples:

```mojo
var s = "hello"
print(s.center(10))       # "  hello   "
print(s.center(11, "*"))  # "***hello***"
print(s.center(3))        # "hello" (no padding)
```

**Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A center-justified string of length `width`, or the original string literal (as a `String`) if its length is already greater than or equal to `width`. ### `startswith` `startswith(self, prefix: StringSlice[origin], start: Int = 0, end: Int = -1) -> Bool` Checks if the string literal starts with the specified prefix between start and end positions. Returns True if found and False otherwise. **Args:** * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix to check. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start offset from which to check. * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end offset from which to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `self[start:end]` is prefixed by the input prefix. ### `endswith` `endswith(self, suffix: StringSlice[origin], start: Int = 0, end: Int = -1) -> Bool` Checks if the string literal ends with the specified suffix between start and end positions. Returns True if found and False otherwise. **Args:** * ​suffix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The suffix to check. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start offset from which to check.
* ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end offset from which to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self[start:end]` is suffixed by the input suffix. ### `isdigit` `isdigit(self) -> Bool` Returns True if all characters in the string literal are digits. Note that this currently only works with ASCII strings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all characters are digits, else False. ### `isupper` `isupper(self) -> Bool` Returns True if all cased characters in the string literal are uppercase and there is at least one cased character. Note that this currently only works with ASCII strings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all cased characters in the string literal are uppercase and there is at least one cased character, False otherwise. ### `islower` `islower(self) -> Bool` Returns True if all cased characters in the string literal are lowercase and there is at least one cased character. Note that this currently only works with ASCII strings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all cased characters in the string literal are lowercase and there is at least one cased character, False otherwise. ### `strip` `strip(self) -> StaticString` Return a copy of the string literal with leading and trailing whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. **Returns:** `StaticString`: A string with no leading or trailing whitespaces. `strip(self, chars: StringSlice[origin]) -> StaticString` Return a copy of the string literal with leading and trailing characters removed. **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. Defaults to whitespace. **Returns:** `StaticString`: A string with no leading or trailing characters. 
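A minimal sketch exercising the predicate methods and `strip` documented above (string values are illustrative; output comments follow the documented ASCII-only behavior):

```mojo
fn main():
    var s = "Hello World"
    print(s.startswith("Hello"))  # True
    print(s.endswith("World"))    # True
    print("12345".isdigit())      # True
    print("HELLO".isupper())      # True
    print("Hello".islower())      # False (contains an uppercase cased character)
    # strip removes leading and trailing ASCII whitespace.
    print("  padded  ".strip())   # padded
```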
### `rstrip` `rstrip(self, chars: StringSlice[origin]) -> StaticString` Return a copy of the string literal with trailing characters removed. **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. Defaults to whitespace. **Returns:** `StaticString`: A string with no trailing characters. `rstrip(self) -> StaticString` Return a copy of the string with trailing whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. **Returns:** `StaticString`: A copy of the string with no trailing whitespaces. ### `lstrip` `lstrip(self, chars: StringSlice[origin]) -> StaticString` Return a copy of the string with leading characters removed. **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. Defaults to whitespace. **Returns:** `StaticString`: A copy of the string with no leading characters. `lstrip(self) -> StaticString` Return a copy of the string with leading whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. **Returns:** `StaticString`: A copy of the string with no leading whitespaces. ### `format` `format[*Ts: Stringable & Representable](self, *args: *Ts) -> String` Produce a formatted string using the current string as a template. The template, or "format string" can contain literal text and/or replacement fields delimited with curly braces (`{}`). Returns a copy of the format string with the replacement fields replaced with string representations of the `args` arguments. For more information, see the discussion in the [`format` module](/mojo/stdlib/collections/string/format/). 
Example: ```mojo # Manual indexing: print("{0} {1} {0}".format("Mojo", 1.125)) # Mojo 1.125 Mojo # Automatic indexing: print("{} {}".format(True, "hello world")) # True hello world ``` **Parameters:** * ​\*Ts ([`Stringable`](/mojo/stdlib/builtin/str/Stringable) & [`Representable`](/mojo/stdlib/builtin/repr/Representable)): The types of substitution values that implement `Representable` and `Stringable` (to be changed and made more flexible). **Args:** * ​\*args (`*Ts`): The substitution values. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The template with the given values substituted. **Raises:** If the format string is invalid or argument count/types don't match. ### `join` `join[T: Copyable & Writable, //](self, elems: Span[T, origin]) -> String` Joins string elements using the current string as a delimiter. Defaults to writing to the stack if total bytes of `elems` is less than `buffer_size`, otherwise will allocate once to the heap and write directly into that. The `buffer_size` defaults to 4096 bytes to match the default page size on arm64 and x86-64. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): The type of the elements. Must implement the `Copyable`, and `Writable` traits. **Args:** * ​elems ([`Span`](/mojo/stdlib/memory/span/Span)): The input values. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The joined string. ### `split` `split(self, sep: StringSlice[origin]) -> List[StaticString]` Split the string by a separator. 
Examples: ```mojo # Splitting a space _ = StringSlice("hello world").split(" ") # ["hello", "world"] # Splitting adjacent separators _ = StringSlice("hello,,world").split(",") # ["hello", "", "world"] # Splitting with starting or ending separators _ = StringSlice(",1,2,3,").split(",") # ['', '1', '2', '3', ''] # Splitting with an empty separator _ = StringSlice("123").split("") # ['', '1', '2', '3', ''] ``` **Args:** * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to split on. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: StringSlice[origin], maxsplit: Int) -> List[StaticString]` Split the string by a separator. Examples: ```mojo # Splitting with maxsplit _ = StringSlice("1,2,3").split(",", maxsplit=1) # ['1', '2,3'] # Splitting with starting or ending separators _ = StringSlice(",1,2,3,").split(",", maxsplit=1) # ['', '1,2,3,'] # Splitting with an empty separator _ = StringSlice("123").split("", maxsplit=1) # ['', '123'] ``` **Args:** * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to split on. * ​maxsplit ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum amount of items to split from String. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: NoneType = None) -> List[StaticString]` Split the string by every Whitespace separator. 
Examples: ```mojo # Splitting an empty string or filled with whitespaces _ = StringSlice(" ").split() # [] _ = StringSlice("").split() # [] # Splitting a string with leading, trailing, and middle whitespaces _ = StringSlice(" hello world ").split() # ["hello", "world"] # Splitting adjacent universal newlines: _ = StringSlice( "hello \t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029world" ).split() # ["hello", "world"] ``` **Args:** * ​sep ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): None. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: NoneType = None, *, maxsplit: Int) -> List[StaticString]` Split the string by every Whitespace separator. Examples: ```mojo # Splitting with maxsplit _ = StringSlice("1 2 3").split(maxsplit=1) # ['1', '2 3'] ``` **Args:** * ​sep ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): None. * ​maxsplit ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum amount of items to split from String. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator.
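`split` and `join` compose naturally; a minimal round-trip sketch, assuming the split result (a `List`) converts implicitly to the `Span` that `join` accepts:

```mojo
fn main():
    # Split on a separator, then rejoin with a different delimiter.
    var parts = StringSlice("a,b,c").split(",")
    print(" | ".join(parts))  # a | b | c
```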
--- ## string_literal
Implements the StringLiteral struct. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral): This type represents a string literal.
--- ## swap
Implements the built-in `swap` function. These are Mojo built-ins, so you don't need to import them. ## Functions * [​`swap`](/mojo/stdlib/builtin/swap/swap): Swaps the two given arguments.
--- ## swap (Swap)
`swap[T: Movable](mut lhs: T, mut rhs: T)` Swaps the two given arguments. **Parameters:** * ​T ([`Movable`](/mojo/stdlib/builtin/value/Movable)): Constrained to Movable types. **Args:** * ​lhs (`T`): Argument value swapped with rhs. * ​rhs (`T`): Argument value swapped with lhs.
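A minimal usage sketch:

```mojo
fn main():
    var a = 1
    var b = 2
    # swap exchanges the two values in place.
    swap(a, b)
    print(a, b)  # 2 1
```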
--- ## Tuple
`struct Tuple[*element_types: Movable]` The type of a literal tuple expression. A tuple consists of zero or more values, separated by commas. ## Parameters * ​\*element\_types ([`Movable`](/mojo/stdlib/builtin/value/Movable)): The element types. ## Fields * ​storage: The underlying MLIR pack storage (`!kgen.pack` of `element_types`) for the tuple's elements. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ## Methods ### `__init__` `__init__(out self: Tuple[])` Construct an empty tuple. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) `__init__(out self, var *args: *element_types)` Construct the tuple. **Args:** * ​\*args (`*element_types`): Initial values. `__init__(out self, *, var storage: VariadicPack[is_owned, origin, Movable, element_types])` Construct the tuple from a low-level internal representation. **Args:** * ​storage ([`VariadicPack`](/mojo/stdlib/builtin/variadics/VariadicPack)): The variadic pack storage to construct from. `__init__[*elt_types: Movable & Defaultable](out self: Tuple[elt_types])` Construct a tuple with default-initialized elements. **Parameters:** * ​\*elt\_types ([`Movable`](/mojo/stdlib/builtin/value/Movable) & [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable)): The types of the elements contained in the Tuple. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple) ### `__copyinit__` `__copyinit__(out self, existing: Self)` Copy construct the tuple. 
**Args:** * ​existing (`Self`): The value to copy from. ### `__moveinit__` `__moveinit__(out self, deinit existing: Self)` Move construct the tuple. **Args:** * ​existing (`Self`): The value to move from. ### `__del__` `__del__(deinit self)` Destructor that destroys all of the elements. ### `__getitem__` `__getitem__[idx: Int](ref self) -> ref [self] element_types[idx._mlir_value]` Get a reference to an element in the tuple. **Parameters:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The element to return. **Returns:** `ref`: A reference to the specified element. ### `__lt__` `__lt__[self_elt_types: Variadic[Movable & Comparable], other_elt_types: Variadic[Movable & Comparable], //](self: Tuple[self_elt_types], other: Tuple[other_elt_types]) -> Bool` Compare this tuple to another tuple using less than comparison. **Parameters:** * ​self\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the Tuple. * ​other\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the other Tuple. **Args:** * ​other ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The other tuple to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this tuple is less than the other tuple, False otherwise. ### `__le__` `__le__[self_elt_types: Variadic[Movable & Comparable], other_elt_types: Variadic[Movable & Comparable], //](self: Tuple[self_elt_types], other: Tuple[other_elt_types]) -> Bool` Compare this tuple to another tuple using less than or equal to comparison. **Parameters:** * ​self\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the Tuple. * ​other\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the other Tuple. 
**Args:** * ​other ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The other tuple to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this tuple is less than or equal to the other tuple, False otherwise. ### `__eq__` `__eq__[self_elt_types: Variadic[Movable & Equatable], other_elt_types: Variadic[Movable & Equatable]](self: Tuple[self_elt_types], other: Tuple[other_elt_types]) -> Bool` Compare this tuple to another tuple using equality comparison. **Parameters:** * ​self\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Equatable]`): The types of the elements contained in the Tuple. * ​other\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Equatable]`): The types of the elements contained in the other Tuple. **Args:** * ​other ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The other tuple to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this tuple is equal to the other tuple, False otherwise. ### `__ne__` `__ne__[self_elt_types: Variadic[Movable & Equatable], other_elt_types: Variadic[Movable & Equatable]](self: Tuple[self_elt_types], other: Tuple[other_elt_types]) -> Bool` Compare this tuple to another tuple using inequality comparison. **Parameters:** * ​self\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Equatable]`): The types of the elements contained in the Tuple. * ​other\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Equatable]`): The types of the elements contained in the other Tuple. **Args:** * ​other ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The other tuple to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this tuple is not equal to the other tuple, False otherwise. 
### `__gt__` `__gt__[self_elt_types: Variadic[Movable & Comparable], other_elt_types: Variadic[Movable & Comparable], //](self: Tuple[self_elt_types], other: Tuple[other_elt_types]) -> Bool` Compare this tuple to another tuple using greater than comparison. **Parameters:** * ​self\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the Tuple. * ​other\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the other Tuple. **Args:** * ​other ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The other tuple to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this tuple is greater than the other tuple, False otherwise. ### `__ge__` `__ge__[self_elt_types: Variadic[Movable & Comparable], other_elt_types: Variadic[Movable & Comparable], //](self: Tuple[self_elt_types], other: Tuple[other_elt_types]) -> Bool` Compare this tuple to another tuple using greater than or equal to comparison. **Parameters:** * ​self\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the Tuple. * ​other\_elt\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)`[Movable & Comparable]`): The types of the elements contained in the other Tuple. **Args:** * ​other ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The other tuple to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this tuple is greater than or equal to the other tuple, False otherwise. ### `__contains__` `__contains__[T: Equatable](self, value: T) -> Bool` Return whether the tuple contains the specified value. For example: ```mojo var t = Tuple(True, 1, 2.5) if 1 in t: print("t contains 1") ``` **Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable)): The type of the value. **Args:** * ​value (`T`): The value to search for. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the value is in the tuple, False otherwise. ### `__len__` `static __len__() -> Int` Return the number of elements in the tuple. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The tuple length. `__len__(self) -> Int` Get the number of elements in the tuple. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The tuple length. ### `reverse` `reverse(deinit self) -> Tuple[...]` Return a new tuple with the elements in reverse order. The element types of the result are `element_types` in reverse order. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): A new tuple with the elements in reverse order. ### `concat` `concat[*other_element_types: Movable](deinit self, deinit other: Tuple[other_element_types]) -> Tuple[...]` Return a new tuple that concatenates this tuple with another. The element types of the result are `element_types` followed by `other_element_types`. **Parameters:** * ​\*other\_element\_types ([`Movable`](/mojo/stdlib/builtin/value/Movable)): The types of the elements contained in the other Tuple. **Args:** * ​other ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The other tuple to concatenate. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): A new tuple with the concatenated elements.
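A short sketch pulling together construction, indexing, `len`, membership, and element-wise comparison on `Tuple` (values are illustrative):

```mojo
fn main():
    var t = Tuple(1, 2.5, True)
    print(t[0])    # 1
    print(len(t))  # 3
    if 1 in t:
        print("t contains 1")
    # Lexicographic comparison of tuples with Comparable elements.
    print(Tuple(1, 2) < Tuple(1, 3))  # True
```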
--- ## tuple (Tuple)
Implements the Tuple type. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): The type of a literal tuple expression.
--- ## Origin
`@register_passable(trivial)` `struct Origin[mut: Bool]` This represents an origin reference for a memory value. ## Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the origin is mutable. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `cast_from` `comptime cast_from[mut: Bool, //, o: Origin[mut]] = o._mlir_origin` Cast an existing Origin to be of the specified mutability. This is a low-level way to coerce Origin mutability. This should be used rarely, typically when building low-level fundamental abstractions. Strongly consider alternatives before reaching for this "escape hatch". Safety: This is an UNSAFE operation if used to cast an immutable origin to a mutable origin. Examples: Cast a mutable origin to be immutable: ```mojo struct Container[mut: Bool, //, origin: Origin[mut]]: var data: Int fn imm_borrow(self) -> Container[ImmutOrigin.cast_from[origin]]: pass ``` #### Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The inferred mutability of the origin to cast from. * ​o ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin to cast. ### `external` `comptime external = origin_of()` An external origin of the given mutability. The external origin is guaranteed not to alias any existing origins. An external origin implies there is no previously existing value that this origin aliases. Therefore, the compiler cannot track the origin or the value's lifecycle. The external origin is useful when interfacing with memory that comes from outside the current Mojo program.
--- ## type_aliases
Defines some type aliases. These are Mojo built-ins, so you don't need to import them. ## `comptime` values ### `AnyTrivialRegType` `comptime AnyTrivialRegType = AnyTrivialRegType` Represents any register passable Mojo data type. ### `ImmutAnyOrigin` `comptime ImmutAnyOrigin = ImmutAnyOrigin` The immutable origin that might access any memory value. ### `ImmutOrigin` `comptime ImmutOrigin = ImmutOrigin` Immutable origin reference type. ### `MutAnyOrigin` `comptime MutAnyOrigin = MutAnyOrigin` The mutable origin that might access any memory value. ### `MutOrigin` `comptime MutOrigin = MutOrigin` Mutable origin reference type. ### `Never` `comptime Never = Never` A type that can never have an instance constructed, used as a function result by functions that never return. ### `OriginSet` `comptime OriginSet = OriginSet` A set of origin parameters. ### `StaticConstantOrigin` `comptime StaticConstantOrigin = StaticConstantOrigin` An origin for strings and other always-immutable static constants. ## Structs * [​`Origin`](/mojo/stdlib/builtin/type_aliases/Origin): This represents an origin reference for a memory value.
--- ## UInt
`@register_passable(trivial)` `struct UInt` This type represents an unsigned integer. The size of this unsigned integer is platform-dependent. If you wish to use a fixed size unsigned integer, consider using `UInt8`, `UInt16`, `UInt32`, or `UInt64`. ## Implemented traits [`Absable`](/mojo/stdlib/builtin/math/Absable), [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`CeilDivable`](/mojo/stdlib/math/math/CeilDivable), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`DivModable`](/mojo/stdlib/builtin/math/DivModable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Indexer`](/mojo/stdlib/builtin/int/Indexer), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BITWIDTH` `comptime BITWIDTH = UInt(bit_width_of[DType.uindex]())` The bit width of the integer type. ### `device_type` `comptime device_type = UInt` UInt is remapped to the same type when passed to accelerator devices. ### `MAX` `comptime MAX = UInt.__init__[Scalar[DType.uindex]](Scalar[DType.uindex].MAX)` Returns the maximum integer value. 
### `MIN` `comptime MIN = UInt.__init__[Scalar[DType.uindex]](Scalar[DType.uindex].MIN)` Returns the minimum value of the type. ## Methods ### `__init__` `__init__() -> Self` Default constructor that produces zero. `@implicit` `__init__(value: IntLiteral[value]) -> Self` Construct UInt from the given IntLiteral value. **Args:** * ​value ([`IntLiteral`](/mojo/stdlib/builtin/int_literal/IntLiteral)): The init value. `__init__(value: Int) -> Self` Construct UInt from the given Int value. **Args:** * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): The init value. `__init__[T: Indexer](value: T) -> Self` Construct UInt from the given Indexable value. **Parameters:** * ​T ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type that can index into a collection or pointer. **Args:** * ​value (`T`): The init value. ### `__bool__` `__bool__(self) -> Bool` Convert this UInt to Bool. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): False Bool value if the value is equal to 0 and True otherwise. ### `__pos__` `__pos__(self) -> Self` Return +self. **Returns:** `Self`: The +self value. ### `__invert__` `__invert__(self) -> Self` Return \~self. **Returns:** `Self`: The \~self value. ### `__lt__` `__lt__(self, rhs: Self) -> Bool` Return whether this UInt is strictly less than another. **Args:** * ​rhs (`Self`): The other UInt to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this UInt is less than the other UInt and False otherwise. ### `__le__` `__le__(self, rhs: Self) -> Bool` Compare this UInt to the RHS using LE comparison. **Args:** * ​rhs (`Self`): The other UInt to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this UInt is less than or equal to the RHS UInt and False otherwise. ### `__eq__` `__eq__(self, rhs: Self) -> Bool` Compare this UInt to the RHS using EQ comparison. **Args:** * ​rhs (`Self`): The other UInt to compare against. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this UInt is equal to the RHS UInt and False otherwise. ### `__ne__` `__ne__(self, rhs: Self) -> Bool` Compare this UInt to the RHS using NE comparison. **Args:** * ​rhs (`Self`): The other UInt to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this UInt is non-equal to the RHS UInt and False otherwise. ### `__gt__` `__gt__(self, rhs: Self) -> Bool` Return whether this UInt is strictly greater than another. **Args:** * ​rhs (`Self`): The other UInt to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this UInt is greater than the other UInt and False otherwise. ### `__ge__` `__ge__(self, rhs: Self) -> Bool` Return whether this UInt is greater than or equal to another. **Args:** * ​rhs (`Self`): The other UInt to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this UInt is greater than or equal to the other UInt and False otherwise. ### `__add__` `__add__(self, rhs: Self) -> Self` Return `self + rhs`. **Args:** * ​rhs (`Self`): The value to add. **Returns:** `Self`: `self + rhs` value. ### `__sub__` `__sub__(self, rhs: Self) -> Self` Return `self - rhs`. **Args:** * ​rhs (`Self`): The value to subtract. **Returns:** `Self`: `self - rhs` value. ### `__mul__` `__mul__(self, rhs: Self) -> Self` Return `self * rhs`. **Args:** * ​rhs (`Self`): The value to multiply with. **Returns:** `Self`: `self * rhs` value. ### `__truediv__` `__truediv__(self, rhs: Self) -> Float64` Return the floating point division of `self` and `rhs`. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): `Float64(self)/Float64(rhs)` value. ### `__floordiv__` `__floordiv__(self, rhs: Self) -> Self` Return the division of `self` and `rhs` rounded down to the nearest integer. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** `Self`: `floor(self/rhs)` value. 
### `__mod__` `__mod__(self, rhs: Self) -> Self` Return the remainder of self divided by rhs. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** `Self`: The remainder of dividing self by rhs. ### `__pow__` `__pow__(self, exp: Self) -> Self` Return the value raised to the power of the given exponent. Computes the power of an integer using the Russian Peasant Method. **Args:** * ​exp (`Self`): The exponent value. **Returns:** `Self`: The value of `self` raised to the power of `exp`. ### `__lshift__` `__lshift__(self, rhs: Self) -> Self` Return `self << rhs`. **Args:** * ​rhs (`Self`): The value to shift with. **Returns:** `Self`: `self << rhs`. ### `__rshift__` `__rshift__(self, rhs: Self) -> Self` Return `self >> rhs`. **Args:** * ​rhs (`Self`): The value to shift with. **Returns:** `Self`: `self >> rhs`. ### `__and__` `__and__(self, rhs: Self) -> Self` Return `self & rhs`. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self & rhs`. ### `__or__` `__or__(self, rhs: Self) -> Self` Return `self | rhs`. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self | rhs`. ### `__xor__` `__xor__(self, rhs: Self) -> Self` Return `self ^ rhs`. **Args:** * ​rhs (`Self`): The RHS value. **Returns:** `Self`: `self ^ rhs`. ### `__radd__` `__radd__(self, value: Self) -> Self` Return `value + self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value + self`. ### `__rsub__` `__rsub__(self, value: Self) -> Self` Return `value - self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value - self`. ### `__rmul__` `__rmul__(self, value: Self) -> Self` Return `value * self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value * self`. ### `__rfloordiv__` `__rfloordiv__(self, value: Self) -> Self` Return `value // self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value // self`. ### `__rmod__` `__rmod__(self, value: Self) -> Self` Return `value % self`. 
**Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value % self`. ### `__rpow__` `__rpow__(self, value: Self) -> Self` Return `pow(value,self)`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `pow(value,self)`. ### `__rlshift__` `__rlshift__(self, value: Self) -> Self` Return `value << self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value << self`. ### `__rrshift__` `__rrshift__(self, value: Self) -> Self` Return `value >> self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value >> self`. ### `__rand__` `__rand__(self, value: Self) -> Self` Return `value & self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value & self`. ### `__ror__` `__ror__(self, value: Self) -> Self` Return `value | self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value | self`. ### `__rxor__` `__rxor__(self, value: Self) -> Self` Return `value ^ self`. **Args:** * ​value (`Self`): The other value. **Returns:** `Self`: `value ^ self`. ### `__iadd__` `__iadd__(mut self, rhs: Self)` Compute `self + rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__isub__` `__isub__(mut self, rhs: Self)` Compute `self - rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__imul__` `__imul__(mut self, rhs: Self)` Compute self\*rhs and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__itruediv__` `__itruediv__(mut self, rhs: Self)` Compute `self / rhs`, convert to int, and save the result in self. Since `floor(self / rhs)` is equivalent to `self // rhs`, this yields the same as `__ifloordiv__`. **Args:** * ​rhs (`Self`): The RHS value. ### `__ifloordiv__` `__ifloordiv__(mut self, rhs: Self)` Compute `self // rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__imod__` `__imod__(mut self, rhs: Self)` Compute `self % rhs` and save the result in self. 
**Args:** * ​rhs (`Self`): The RHS value. ### `__ipow__` `__ipow__(mut self, rhs: Self)` Compute `pow(self, rhs)` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__ilshift__` `__ilshift__(mut self, rhs: Self)` Compute `self << rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__irshift__` `__irshift__(mut self, rhs: Self)` Compute `self >> rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__iand__` `__iand__(mut self, rhs: Self)` Compute `self & rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__ixor__` `__ixor__(mut self, rhs: Self)` Compute `self ^ rhs` and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `__ior__` `__ior__(mut self, rhs: Self)` Compute self|rhs and save the result in self. **Args:** * ​rhs (`Self`): The RHS value. ### `get_type_name` `static get_type_name() -> String` Gets this type's name, for use in error messages when handing arguments to kernels. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This type's name. ### `__divmod__` `__divmod__(self, rhs: Self) -> Tuple[UInt, UInt]` Computes both the quotient and remainder using integer division. **Args:** * ​rhs (`Self`): The value to divide on. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): The quotient and remainder as a `Tuple(self // rhs, self % rhs)`. ### `__mlir_index__` `__mlir_index__(self) -> __mlir_type.index` Convert to index. **Returns:** `__mlir_type.index`: The corresponding \_\_mlir\_type.index value. ### `__int__` `__int__(self) -> Int` Gets the integral value, wrapping to a negative number on overflow. 
**Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value as an integer. ### `__abs__` `__abs__(self) -> Self` Return the absolute value of the UInt value. **Returns:** `Self`: The absolute value. ### `__ceil__` `__ceil__(self) -> Self` Return the ceiling of the UInt value, which is itself. **Returns:** `Self`: The UInt value itself. ### `__floor__` `__floor__(self) -> Self` Return the floor of the UInt value, which is itself. **Returns:** `Self`: The UInt value itself. ### `__round__` `__round__(self) -> Self` Return the rounded value of the UInt value, which is itself. **Returns:** `Self`: The UInt value itself. `__round__(self, ndigits: Self) -> Self` Return the rounded value of the UInt value, which is itself. **Args:** * ​ndigits (`Self`): The number of digits to round to. **Returns:** `Self`: The UInt value itself if ndigits >= 0 else the rounded value. ### `__trunc__` `__trunc__(self) -> Self` Return the truncated UInt value, which is itself. **Returns:** `Self`: The UInt value itself. ### `__ceildiv__` `__ceildiv__(self, denominator: Self) -> Self` Return the rounded-up result of dividing self by denominator. **Args:** * ​denominator (`Self`): The denominator. **Returns:** `Self`: The ceiling of dividing numerator by denominator. ### `is_power_of_two` `is_power_of_two(self) -> Bool` Check if the integer is a (non-zero) power of two. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the integer is a power of two, False otherwise. ### `write_to` `write_to(self, mut writer: T)` Formats this integer to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `__str__` `__str__(self) -> String` Convert this UInt to a string. For example: ```mojo x = UInt(50) assert_equal(String(x), "50") ``` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of this UInt. ### `__repr__` `__repr__(self) -> String` Convert this UInt to its string representation. For example: 
```mojo x = UInt(50) assert_equal(repr(x), "UInt(50)") ``` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of this UInt. ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with this uint value. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance.
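Tying a few of the methods above together, a minimal sketch of `UInt` floor division, remainder, and the power-of-two check:

```mojo
fn main():
    var x = UInt(17)
    # Floor division and remainder, as computed by __divmod__
    print(x // UInt(5))  # 3
    print(x % UInt(5))   # 2
    # Non-zero power-of-two check
    print(UInt(64).is_power_of_two())  # True
    print(UInt(63).is_power_of_two())  # False
```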
--- ## uint
Implements the UInt class. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`UInt`](/mojo/stdlib/builtin/uint/UInt): This type represents an unsigned integer.
--- ## Copyable
The Copyable trait denotes a type whose value can be explicitly copied. Example implementing the `Copyable` trait on `Foo`, which requires the `__copyinit__` method: ```mojo struct Foo(Copyable): var s: String fn __init__(out self, s: String): self.s = s fn __copyinit__(out self, other: Self): print("copying value") self.s = other.s ``` You can now copy objects inside a generic function: ```mojo fn copy_return[T: Copyable](foo: T) -> T: var copy = foo.copy() return copy^ var foo = Foo("test") var res = copy_return(foo) ``` ```plaintext copying value ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. 
In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## Defaultable
The `Defaultable` trait describes a type with a default constructor. Implementing the `Defaultable` trait requires the type to define an `__init__` method with no arguments: ```mojo struct Foo(Defaultable): var s: String fn __init__(out self): self.s = "default" ``` You can now construct a generic `Defaultable` type: ```mojo fn default_init[T: Defaultable]() -> T: return T() var foo = default_init[Foo]() print(foo.s) ``` ```plaintext default ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ## Required methods ### `__init__` `__init__(out self: _Self)` Create a default instance of the value. **Returns:** `_Self`
--- ## ImplicitlyCopyable
A marker trait that permits the compiler to insert implicit calls to `__copyinit__` in order to make a copy of the object when needed. Conforming a type to `ImplicitlyCopyable` gives the Mojo language permission to implicitly insert a call to that type's copy constructor whenever a borrowed instance of the type is passed or assigned where an owned value is required. Types that are expensive to copy, or where implicit copying could mask a logic error, typically should not be `ImplicitlyCopyable`. The `ImplicitlyCopyable` trait is a marker trait, meaning that it does not require a type to provide any additional methods or associated aliases to conform to this trait. However, all `ImplicitlyCopyable` types are required to conform to `Copyable`, which ensures there is only one definition for the logic of how a type is copied. **Note:** `ImplicitlyCopyable` should only be used to mark structs that may be copied implicitly. It should not be used as a trait bound (`T: ImplicitlyCopyable`) on functions or types, except in special circumstances. Generic code that may perform copies should always use the more general `T: Copyable` bound. This ensures that generic code is usable with all types that are copyable, regardless of whether they opt in to implicit copying. 
### Examples A type can opt in to implicit copying by conforming to `ImplicitlyCopyable` (in the example below, the compiler also synthesizes a default field-wise `__copyinit__()` implementation, as the user didn't provide a definition): ```mojo @fieldwise_init struct Point(ImplicitlyCopyable): var x: Int var y: Int fn main(): var p = Point(5, 10) # Perform an implicit copy of `p` var p2 = p ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## Movable
The Movable trait denotes a type whose value can be moved. Implement the `Movable` trait on `Foo` which requires the `__moveinit__` method: ```mojo struct Foo(Movable): fn __init__(out self): pass fn __moveinit__(out self, deinit existing: Self): print("moving") ``` You can now use the ^ suffix to move the object instead of copying it inside generic functions: ```mojo fn return_foo[T: Movable](var foo: T) -> T: return foo^ var foo = Foo() var res = return_foo(foo^) ``` ```plaintext moving ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self`
--- ## value
Defines core value traits. These are Mojo built-ins, so you don't need to import them. ## `comptime` values ### `ExplicitlyCopyable` `comptime ExplicitlyCopyable = Copyable` Deprecated alias for `Copyable`. **Deprecated:** Use `Copyable` or `ImplicitlyCopyable` instead. `Copyable` on its own no longer implies implicit copyability. ## Traits * [​`Copyable`](/mojo/stdlib/builtin/value/Copyable): The Copyable trait denotes a type whose value can be explicitly copied. * [​`Defaultable`](/mojo/stdlib/builtin/value/Defaultable): The `Defaultable` trait describes a type with a default constructor. * [​`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable): A marker trait to permit compiler to insert implicit calls to `__copyinit__` in order to make a copy of the object when needed. * [​`Movable`](/mojo/stdlib/builtin/value/Movable): The Movable trait denotes a type whose value can be moved. ## Functions * [​`materialize`](/mojo/stdlib/builtin/value/materialize): Explicitly materialize a compile-time parameter into a run-time value.
--- ## materialize
`materialize[T: AnyType, //, value: T](out result: T)` Explicitly materialize a compile-time parameter into a run-time value. **Parameters:** * ​T ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The type of the value to materialize. * ​value (`T`): The compile-time parameter value to materialize. **Returns:** `T`: The materialized run-time value.
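As a hedged sketch of how `materialize` might be used, assuming the common pattern of forwarding a compile-time parameter into run-time code (the function name `as_runtime` is hypothetical):

```mojo
fn as_runtime[v: Int]() -> Int:
    # Explicitly turn the compile-time parameter `v` into a run-time value.
    return materialize[v]()

fn main():
    print(as_runtime[7]())
```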
--- ## Variadic
`struct Variadic` A namespace for variadic utilities. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `concat` `comptime concat[T: AnyTrait[AnyType], //, *Ts: Variadic[T]] = #kgen.variadic.concat<#kgen.param.decl.ref<"Ts"> : !kgen.variadic> T>>>` Represents the concatenation of multiple variadic sequences of types. #### Parameters * ​T (`AnyTrait`): The trait that types in the variadic sequences must conform to. * ​\*Ts ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): The variadic sequences to concatenate. ### `contains` `comptime contains[Trait: AnyTrait[AnyType], //, type: Trait, element_types: Variadic[Trait]] = #kgen.variadic.reduce<#kgen.variadic<{:i1 0}> : !kgen.variadic>, #kgen.param.decl.ref<"element_types"> : !kgen.variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait>, #kgen.gen<[cond(eq(:type !kgen.param<:!kgen.param<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> variadic_get(:variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> *(0,1), *(0,2))>, !kgen.param<:!kgen.param<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> type>), sugar_builtin(apply(:!lit.generator<() -> !lit.struct<@stdlib::@builtin::@bool::@Bool>> @stdlib::@sys::@intrinsics::@"_type_is_eq_parse_time[::AnyType,::AnyType]()"<:trait<@stdlib::@builtin::@anytype::@AnyType> !kgen.param<:!kgen.param<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> variadic_get(:variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> *(0,1), *(0,2))>, :trait<@stdlib::@builtin::@anytype::@AnyType> !kgen.param<:!kgen.param<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> type>>), {_mlir_value: i1 = eq(:type !kgen.param<:!kgen.param<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> 
variadic_get(:variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> *(0,1), *(0,2))>, !kgen.param<:!kgen.param<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait> type>)}), variadic_get(:variadic> *(0,0), 0))]> : !kgen.generator>, "VA": variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> Trait>, "idx": index>variadic>>>>[0]` Check if a type is contained in a variadic sequence. #### Parameters * ​Trait (`AnyTrait`): The trait that the types conform to. * ​type (`Trait`): The type to check for. * ​element\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): The variadic sequence of types to search. ### `empty_of_trait` `comptime empty_of_trait[T: AnyTrait[AnyType]]` Empty comptime variadic of type values. #### Parameters * ​T (`AnyTrait`): The trait that types in the variadic sequence must conform to. ### `empty_of_type` `comptime empty_of_type[T: AnyType]` Empty comptime variadic of values. #### Parameters * ​T ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The type of values in the variadic sequence. ### `reverse` `comptime reverse[T: AnyTrait[AnyType], //, *element_types: T] = #kgen.variadic.reduce<#kgen.variadic<> : !kgen.variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> T>, #kgen.param.decl.ref<"element_types"> : !kgen.variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> T>, #kgen.gen<#kgen.variadic.concat<#kgen.variadic<*(0,0), [variadic_get(:variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> T> *(0,1), add(mul(*(0,2), -1), #kgen.variadic.size<#kgen.param.index.ref<0, 1> : !kgen.variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> T>>, -1))]> : !kgen.variadic> T>>>> : !kgen.generator> T>, "VA": variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> T>, "idx": index>variadic<:!lit.anytrait<<@stdlib::@builtin::@anytype::@AnyType>> T>>>>` A wrapper to reverse a variadic sequence of types. 
#### Parameters * ​T (`AnyTrait`): The trait that the types conform to. * ​\*element\_types (`T`): The variadic sequence of types to reverse. ### `splat` `comptime splat[type: AnyType, count: Int] = #kgen.variadic.splat<:trait<@stdlib::@builtin::@anytype::@AnyType> *"type", #lit.struct.extract<:!lit.struct<@stdlib::@builtin::@int::@Int> count, "_mlir_value"> : index>` Splat a type into a variadic sequence. #### Parameters * ​type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The type to splat. * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of times to splat the type. ### `types` `comptime types[T: AnyTrait[AnyType], //, *Ts: T] = Ts` Turn discrete type values (bound by `T`) into a single variadic. #### Parameters * ​T (`AnyTrait`): The trait that the types must conform to. * ​\*Ts (`T`): The types to collect into a variadic sequence. ### `TypesOfTrait` `comptime TypesOfTrait[T: AnyTrait[UnknownDestructibility]] = Variadic[T]` Represents a raw variadic sequence of types that satisfy the specified trait. #### Parameters * ​T (`AnyTrait`): The trait that types in the variadic sequence must conform to. ### `values` `comptime values[T: AnyType, //, *values_: T] = values_` Turn discrete values (bound by `T`) into a single variadic. #### Parameters * ​T ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The type of the values. * ​\*values\_ (`T`): The values to collect into a variadic sequence. ### `ValuesOfType` `comptime ValuesOfType[type: UnknownDestructibility] = Variadic[type]` Represents a raw variadic sequence of values of the specified type. #### Parameters * ​type ([`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility)): The type of values in the variadic sequence. ## Methods ### `size` `static size[T: UnknownDestructibility](seq: Variadic[T]) -> Int` Returns the length of a variadic sequence. 
**Parameters:** * ​T ([`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility)): The type of values in the sequence. **Args:** * ​seq ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): The variadic sequence to measure. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the variadic sequence. `static size[T: AnyTrait[UnknownDestructibility]](seq: Variadic[T]) -> Int` Returns the length of a variadic sequence. **Parameters:** * ​T (`AnyTrait`): The trait that types in the sequence must conform to. **Args:** * ​seq ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): The variadic sequence of types to measure. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of the variadic sequence.
--- ## VariadicList
`@register_passable(trivial)` `struct VariadicList[type: AnyTrivialRegType]` A utility class to access homogeneous variadic function arguments. `VariadicList` is used when you need to accept variadic arguments where all arguments have the same type. Unlike `VariadicPack` (which is heterogeneous), `VariadicList` requires all elements to have the same concrete type. At runtime, `VariadicList` is treated as a homogeneous array. Because all the elements have the same type, each element has the same size and memory layout, so the compiler can generate code that works to access any index at runtime. Therefore, indexing into `VariadicList` can use runtime indices with regular `for` loops, whereas indexing into `VariadicPack` requires compile-time indices using `@parameter for` loops. For example, in the following function signature, `*args: Int` creates a `VariadicList` because it uses a single type `Int` instead of a variadic type parameter. The `*` before `args` indicates that `args` is a variadic argument, which means that the function can accept any number of arguments, but all arguments must have the same type `Int`. ```mojo fn sum_values(*args: Int) -> Int: var total = 0 # Can use regular for loop because args is a VariadicList for i in range(len(args)): total += args[i] # All elements are Int, so uniform access return total def main(): print(sum_values(1, 2, 3, 4, 5)) ``` ## Parameters * ​type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of the elements in the list. ## Fields * ​value (`Variadic[type]`): The underlying storage for the variadic list. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _VariadicListIter[type]` The iterator type for this variadic list. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `@implicit` `__init__(*value: type) -> Self` Constructs a VariadicList from a variadic list of arguments. **Args:** * ​\*value (`type`): The variadic argument list to construct the variadic list with. ### `__getitem__` `__getitem__[I: Indexer](self, idx: I) -> type` Gets a single element on the variadic list. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): A type that can be used as an index. **Args:** * ​idx (`I`): The index of the element to access on the list. **Returns:** `type`: The element on the list corresponding to the given index. ### `__len__` `__len__(self) -> Int` Gets the size of the list. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements on the variadic list. ### `__iter__` `__iter__(ref self) -> _VariadicListIter[type]` Iterate over the list. **Returns:** `_VariadicListIter`: An iterator to the start of the list.
--- ## VariadicListMem
`struct VariadicListMem[elt_is_mutable: Bool, //, element_type: AnyType, origin: Origin[elt_is_mutable], is_owned: Bool]` A utility class to access variadic function arguments of memory-only types that may have ownership. It exposes references to the elements in a way that can be enumerated. Each element may be accessed with `elt[]`. ## Parameters * ​elt\_is\_mutable ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): True if the elements of the list are mutable for a `mut` or `owned` argument. * ​element\_type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The type of the elements in the list. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the underlying elements. * ​is\_owned ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the elements are owned by the list because they are passed as a `var` argument. ## Fields * ​value (`Variadic[ref [origin] element_type]`): The underlying storage, a variadic list of references to elements of the given type. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `reference_type` `comptime reference_type = Pointer[element_type, origin]` The pointer type for references to elements. ## Methods ### `__moveinit__` `__moveinit__(out self, deinit existing: Self)` Move constructor. **Args:** * ​existing (`Self`): The existing VariadicListMem. ### `__del__` `__del__(deinit self)` Destructor that releases elements if owned. ### `__getitem__` `__getitem__(self, idx: Int) -> ref [origin, origin_of(*[0,0])] element_type` Gets a single element on the variadic list. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to access on the list. **Returns:** `ref`: A reference to the element on the list corresponding to the given index. 
### `consume_elements` `consume_elements[elt_handler: fn(idx: Int, var elt: element_type) capturing -> None](deinit self)` Consume the variadic list by transferring ownership of each element into the provided closure one at a time. This is only valid on 'owned' variadic lists. **Parameters:** * ​elt\_handler (`fn(idx: Int, var elt: element_type) capturing -> None`): A function that will be called for each element of the list. ### `__len__` `__len__(self) -> Int` Gets the size of the list. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements on the variadic list. ### `__iter__` `__iter__(self) -> _VariadicListMemIter[element_type, origin, self, is_owned]` Iterate over the list. **Returns:** `_VariadicListMemIter`: An iterator to the start of the list.
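For illustration, a function whose variadic argument has a memory-only type such as `String` receives a `VariadicListMem` behind the scenes; a minimal sketch:

```mojo
fn print_all(*names: String):
    # `names` is a VariadicListMem because String is a memory-only type.
    # Elements have a uniform type, so runtime indices work here.
    for i in range(len(names)):
        print(names[i])

fn main():
    print_all("alpha", "beta", "gamma")
```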
--- ## VariadicPack
`@register_passable` `struct VariadicPack[elt_is_mutable: Bool, //, is_owned: Bool, origin: Origin[elt_is_mutable], element_trait: AnyTrait[UnknownDestructibility], *element_types: element_trait]` A utility class to access heterogeneous variadic function arguments. `VariadicPack` is used when you need to accept variadic arguments where each argument can have a different type, but all types conform to a common trait. Unlike `VariadicList` (which is homogeneous), `VariadicPack` allows each element to have a different concrete type. `VariadicPack` is essentially a heterogeneous tuple that gets lowered to a struct at runtime. Because `VariadicPack` is a heterogeneous tuple (not an array), each element can have a different size and memory layout, which means the compiler needs to know the exact type of each element at compile time to generate the correct memory layout and access code. Therefore, indexing into `VariadicPack` requires compile-time indices using `@parameter for` loops, whereas indexing into `VariadicList` uses runtime indices. For example, in the following function signature, `*args: *ArgTypes` creates a `VariadicPack` because it uses a variadic type parameter `*ArgTypes` instead of a single type. The `*` before `ArgTypes` indicates that `ArgTypes` is a variadic type parameter: the function can accept any number of arguments, each potentially of a different type, as long as every argument type conforms to the `Intable` trait. 
```mojo fn count_many_things[*ArgTypes: Intable](*args: *ArgTypes) -> Int: var total = 0 # Must use @parameter for loop because args is a VariadicPack @parameter for i in range(args.__len__()): # Each args[i] has a different concrete type from *ArgTypes # The compiler generates specific code for each iteration total += Int(args[i]) return total def main(): print(count_many_things(5, 11.7, 12)) # Prints: 28 ``` ## Parameters * ​elt\_is\_mutable ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): True if the elements of the list are mutable for a `mut` or `owned` argument pack. * ​is\_owned ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the elements are owned by the pack. If so, the pack will release the elements when it is destroyed. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the underlying elements. * ​element\_trait (`AnyTrait`): The trait that each element of the pack conforms to. * ​\*element\_types (`element_trait`): The list of types held by the argument pack. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__del__` `__del__(deinit self)` Destructor that releases elements if owned. ### `__getitem__` `__getitem__[index: Int](self) -> ref [origin] element_types[index._mlir_value]` Return a reference to an element of the pack. **Parameters:** * ​index ([`Int`](/mojo/stdlib/builtin/int/Int)): The element of the pack to return. **Returns:** `ref`: A reference to the element. The reference's mutability follows the mutability of the pack argument convention. 
### `consume_elements` `consume_elements[elt_handler: fn[idx: Int](var elt: element_types[idx._mlir_value]) capturing -> None](deinit self)` Consume the variadic pack by transferring ownership of each element into the provided closure one at a time. This is only valid on 'owned' variadic packs. **Parameters:** * ​elt\_handler (`fn[idx: Int](var elt: element_types[idx._mlir_value]) capturing -> None`): A function that will be called for each element of the pack. ### `__len__` `static __len__() -> Int` Return the VariadicPack length. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the variadic pack. `__len__(self) -> Int` Return the VariadicPack length. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the variadic pack.
--- ## variadics
Implements the VariadicList and VariadicPack types. These are Mojo built-ins, so you don't need to import them. ## Structs * [​`Variadic`](/mojo/stdlib/builtin/variadics/Variadic): A namespace for variadic utilities. * [​`VariadicList`](/mojo/stdlib/builtin/variadics/VariadicList): A utility class to access homogeneous variadic function arguments. * [​`VariadicListMem`](/mojo/stdlib/builtin/variadics/VariadicListMem): A utility class to access variadic function arguments of memory-only types that may have ownership. It exposes references to the elements in a way that can be enumerated. Each element may be accessed with `elt[]`. * [​`VariadicPack`](/mojo/stdlib/builtin/variadics/VariadicPack): A utility class to access heterogeneous variadic function arguments.
--- ## BitSet
`struct BitSet[size: Int]` A grow-only set storing non-negative integers efficiently using bits. Each integer element is represented by a single bit within an array of 64-bit words (`Int64`). This structure is optimized for: * **Compactness:** Uses 64 times less memory than `List[Bool]`. * **Speed:** Offers O(1) time complexity for `set`, `clear`, `test`, and `toggle` operations (single word load/store). Ideal for applications like data-flow analysis, graph algorithms, or any task requiring dense sets of small integers where memory and lookup speed are critical. ## Parameters * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of bits the bitset can store. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self)` Initializes an empty BitSet with zero capacity and size. `__init__(out self: BitSet[size], init: SIMD[DType.bool, size])` Initializes a BitSet with the given SIMD vector of booleans. **Args:** * ​init ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): A SIMD vector of booleans to initialize the bitset with. **Returns:** [`BitSet`](/mojo/stdlib/collections/bitset/BitSet) ### `__bool__` `__bool__(self) -> Bool` Checks if the bitset is non-empty (contains at least one set bit). 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if at least one bit is set, False otherwise. ### `__len__` `__len__(self) -> Int` Counts the total number of bits that are set to 1 in the bitset. Uses the efficient `pop_count` intrinsic for each underlying word. The complexity is proportional to the number of words used by the bitset's capacity (`_words_size`), not the logical size (`len`). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total count of set bits (population count). ### `set` `set(mut self, idx: Int)` Sets the bit at the specified index `idx` to 1. If `idx` is greater than or equal to the current logical size, the logical size is updated. Aborts if `idx` is negative or greater than or equal to the compile-time `size`. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The non-negative index of the bit to set (must be < `size`). ### `clear` `clear(mut self, idx: Int)` Clears the bit at the specified index `idx` (sets it to 0). Aborts if `idx` is negative or greater than or equal to the compile-time `size`. Does not change the logical size. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The non-negative index of the bit to clear (must be < `size`). ### `toggle` `toggle(mut self, idx: Int)` Toggles (inverts) the bit at the specified index `idx`. If the bit becomes 1 and `idx` is greater than or equal to the current logical size, the logical size is updated. Aborts if `idx` is negative or greater than or equal to the compile-time `size`. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The non-negative index of the bit to toggle (must be < `size`). ### `test` `test(self, idx: Int) -> Bool` Tests if the bit at the specified index `idx` is set (is 1). Aborts if `idx` is negative or greater than or equal to the compile-time `size`. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The non-negative index of the bit to test (must be < `size`). 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the bit at `idx` is set, False otherwise. ### `clear_all` `clear_all(mut self)` Clears all bits in the set, resetting the logical size to 0. The allocated storage capacity remains unchanged. Equivalent to re-initializing the set with `Self()`. ### `toggle_all` `toggle_all(mut self)` Toggles (inverts) all bits in the set up to the compile-time `size`. ### `set_all` `set_all(mut self)` Sets all bits in the set up to the compile-time `size`. ### `union` `union(self, other: Self) -> Self` Returns a new bitset that is the union of `self` and `other`. **Args:** * ​other (`Self`): The bitset to union with. **Returns:** `Self`: A new bitset containing all elements from both sets. ### `intersection` `intersection(self, other: Self) -> Self` Returns a new bitset that is the intersection of `self` and `other`. **Args:** * ​other (`Self`): The bitset to intersect with. **Returns:** `Self`: A new bitset containing only the elements present in both sets. ### `difference` `difference(self, other: Self) -> Self` Returns a new bitset that is the difference of `self` and `other`. **Args:** * ​other (`Self`): The bitset to subtract from `self`. **Returns:** `Self`: A new bitset containing elements from `self` that are not in `other`. ### `write_to` `write_to(self, mut writer: T)` Writes a string representation of the set bits to the given writer. Outputs the indices of the set bits in ascending order, enclosed in curly braces and separated by commas (e.g., "{1, 5, 42}"). Uses efficient bitwise operations to find set bits without iterating through every possible bit. **Args:** * ​writer (`T`): The writer instance to output the representation to. ### `__repr__` `__repr__(self) -> String` Returns a developer-friendly string representation of the bitset. Currently equivalent to `__str__`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string showing the set bits (e.g., "{1, 5, 42}"). 
### `__str__` `__str__(self) -> String` Returns a user-friendly string representation of the bitset. Formats the set bits as a comma-separated list within curly braces, like "{1, 5, 42}". Uses the `write_to` method internally. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string showing the set bits.
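The set-algebra methods above (`union`, `intersection`, `difference`) can be combined as in this minimal sketch. It assumes the default constructor produces an all-clear set and that `String(...)` works via the `Stringable` conformance listed above; the outputs in the comments follow the `{...}` format documented under `write_to`.

```mojo
from collections import BitSet

fn main():
    var a = BitSet[8]()
    a.set(1)
    a.set(3)

    var b = BitSet[8]()
    b.set(3)
    b.set(5)

    print(String(a.union(b)))         # {1, 3, 5}
    print(String(a.intersection(b)))  # {3}
    print(String(a.difference(b)))    # {1}
```

Because the operands here share the compile-time `size` parameter, the results are also `BitSet[8]` values; mixing different sizes would be a type error.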
--- ## bitset (BitSet)
Provides a compact, grow-only set of non-negative integers. Optimized for space (1 bit per element) and speed (O(1) operations). Offers set/clear/test/toggle and fast population count. The underlying storage grows automatically but does not shrink unless `shrink_to_fit` is called (not implemented yet). Example: ```mojo from collections import BitSet var bs = BitSet[128]() # 128-bit set, all clear bs.set(42) # Mark value 42 as present. if bs.test(42): # Check membership. print("hit") # Prints "hit". bs.clear(42) # Remove 42. print(len(bs)) # Prints 0. ``` ## Structs * [​`BitSet`](/mojo/stdlib/collections/bitset/BitSet): A grow-only set storing non-negative integers efficiently using bits.
--- ## CountTuple
`struct CountTuple[V: KeyElement]` A tuple representing a value and its count in a `Counter`. ## Parameters * ​V ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The value in the `Counter`. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True if V.__copyinit__is_trivial else V.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = True if V.__del__is_trivial else V.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True if V.__moveinit__is_trivial else V.__moveinit__is_trivial` ## Methods ### `__init__` `__init__(out self, value: V, count: UInt)` Create a new `CountTuple`. **Args:** * ​value (`V`): The value in the `Counter`. * ​count ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The count of the value in the `Counter`. ### `__getitem__` `__getitem__(self, idx: Int) -> Variant[V, Int]` Get an element in the `CountTuple`. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The element to return. **Returns:** `Variant`: The value if `idx` is `0` and the count if `idx` is `1`. ### `__lt__` `__lt__(self, other: Self) -> Bool` Compare two `CountTuple`s by count, then by value. **Args:** * ​other (`Self`): The other `CountTuple` to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if this `CountTuple` is less than the other, `False` otherwise. ### `__eq__` `__eq__(self, other: Self) -> Bool` Compare two `CountTuple`s for equality. **Args:** * ​other (`Self`): The other `CountTuple` to compare to. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if the two `CountTuple`s are equal, `False` otherwise.
--- ## Counter
`struct Counter[V: KeyElement, H: Hasher = default_hasher]` A container for counting hashable items. The value type must be specified statically, unlike a Python `Counter`, which can accept arbitrary value types. The value type must implement the `KeyElement` trait, as its values are stored in a dictionary as keys. Usage: ```mojo from collections import Counter var c = Counter[String]("a", "a", "a", "b", "b", "c", "d", "c", "c") print(c["a"]) # prints 3 print(c["b"]) # prints 2 ``` ## Parameters * ​V ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The value type to be counted. Currently must be `KeyElement`. * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The type of the hasher in the underlying dictionary. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _DictKeyIter[V, Int, H, iterable_origin]` The iterator type for this counter. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self)` Create a new, empty `Counter` object. 
`__init__(out self, var *values: V)` Create a new `Counter` from a list of values. Usage: ```mojo from collections import Counter var c = Counter[String]("a", "a", "a", "b", "b", "c", "d", "c", "c") print(c["a"]) # prints 3 print(c["b"]) # prints 2 ``` **Args:** * ​\*values (`V`): A list of values to count. `__init__(out self, items: List[V])` Create a `Counter` from an input iterable. Usage: ```mojo from collections import Counter var c = Counter[String](["a", "a", "a", "b", "b", "c", "d", "c", "c"]) print(c["a"]) # prints 3 print(c["b"]) # prints 2 ``` **Args:** * ​items ([`List`](/mojo/stdlib/collections/list/List)): A list of items to count. ### `__bool__` `__bool__(self) -> Bool` Check if the `Counter` is empty or not. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `False` if the `Counter` is empty, `True` otherwise. ### `__getitem__` `__getitem__(self, key: V) -> Int` Get the count of a key. **Args:** * ​key (`V`): The key to get the count of. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The count of the key. ### `__setitem__` `__setitem__(mut self, value: V, count: Int)` Set the count for a value in the `Counter`. **Args:** * ​value (`V`): The value to associate with the specified count. * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): The count to store in the `Counter`. ### `__neg__` `__neg__(self) -> Self` Subtract from an empty `Counter`. Strips positive and zero counts, and flips the sign on negative counts. **Returns:** `Self`: A new `Counter` with non-positive counts removed and negative counts negated. ### `__pos__` `__pos__(self) -> Self` Return a shallow copy of the `Counter`, stripping non-positive counts. **Returns:** `Self`: A shallow copy of the `Counter`. ### `__eq__` `__eq__(self, other: Self) -> Bool` Check if all counts agree. Missing counts are treated as zero. **Args:** * ​other (`Self`): The other `Counter` to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if the two `Counter`s are equal, `False` otherwise. 
### `__contains__` `__contains__(self, key: V) -> Bool` Check if a given key is in the `Counter` or not. **Args:** * ​key (`V`): The key to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if the key exists in the `Counter`, `False` otherwise. ### `__add__` `__add__(self, other: Self) -> Self` Add counts from two `Counter`s. **Args:** * ​other (`Self`): The other `Counter` to add to this `Counter`. **Returns:** `Self`: A new `Counter` with the counts from both `Counter`s added together. ### `__sub__` `__sub__(self, other: Self) -> Self` Subtract counts, but keep only results with positive counts. **Args:** * ​other (`Self`): The other `Counter` to subtract from this `Counter`. **Returns:** `Self`: A new `Counter` with the counts from the other `Counter` subtracted from this `Counter`. ### `__and__` `__and__(self, other: Self) -> Self` Intersection: keep common elements with the minimum count. **Args:** * ​other (`Self`): The other `Counter` to intersect with. **Returns:** `Self`: A new `Counter` with the common elements and the minimum count of the two `Counter`s. ### `__or__` `__or__(self, other: Self) -> Self` Union: keep all elements with the maximum count. **Args:** * ​other (`Self`): The other `Counter` to union with. **Returns:** `Self`: A new `Counter` with all elements and the maximum count of the two `Counter`s. ### `__iadd__` `__iadd__(mut self, other: Self)` Add counts from another `Counter` to this `Counter`. **Args:** * ​other (`Self`): The other `Counter` to add to this `Counter`. ### `__isub__` `__isub__(mut self, other: Self)` Subtract the counts of another `Counter` from this `Counter`, but keep only results with positive counts. **Args:** * ​other (`Self`): The other `Counter` to subtract from this `Counter`. ### `__iand__` `__iand__(mut self, other: Self)` Intersection: keep common elements with the minimum count. **Args:** * ​other (`Self`): The other `Counter` to intersect with. 
### `__ior__` `__ior__(mut self, other: Self)` Union: keep all elements with the maximum count. **Args:** * ​other (`Self`): The other `Counter` to union with. ### `fromkeys` `static fromkeys(keys: List[V], value: Int) -> Self` Create a new `Counter` from a list of keys and a default value. **Args:** * ​keys ([`List`](/mojo/stdlib/collections/list/List)): The keys to create the `Counter` from. * ​value ([`Int`](/mojo/stdlib/builtin/int/Int)): The default value to associate with each key. **Returns:** `Self`: A new `Counter` with the keys and default value. ### `__iter__` `__iter__(ref self) -> _DictKeyIter[V, Int, H, self_is_origin]` Iterate over the `Counter`'s keys as immutable references. **Returns:** `_DictKeyIter`: An iterator of immutable references to the `Counter` keys. ### `__len__` `__len__(self) -> Int` Returns the number of elements currently stored in the `Counter`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the `Counter`. ### `le` `le(self, other: Self) -> Bool` Check if all counts are less than or equal to those in the other `Counter`. Note that since we check that *all* counts satisfy the condition, this comparison does not make `Counter`s totally ordered. **Args:** * ​other (`Self`): The other `Counter` to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if all counts are less than or equal to those in the other `Counter`, `False` otherwise. ### `lt` `lt(self, other: Self) -> Bool` Check if all counts are less than those in the other `Counter`. Note that since we check that *all* counts satisfy the condition, this comparison does not make `Counter`s totally ordered. **Args:** * ​other (`Self`): The other `Counter` to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if all counts are less than those in the other `Counter`, `False` otherwise. ### `gt` `gt(self, other: Self) -> Bool` Check if all counts are greater than those in the other `Counter`. 
Note that since we check that *all* counts satisfy the condition, this comparison does not make `Counter`s totally ordered. **Args:** * ​other (`Self`): The other `Counter` to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if all counts are greater than those in the other `Counter`, `False` otherwise. ### `ge` `ge(self, other: Self) -> Bool` Check if all counts are greater than or equal to those in the other `Counter`. Note that since we check that *all* counts satisfy the condition, this comparison does not make `Counter`s totally ordered. **Args:** * ​other (`Self`): The other `Counter` to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if all counts are greater than or equal to those in the other `Counter`, `False` otherwise. ### `get` `get(self, value: V) -> Optional[Int]` Get the count of a value from the `Counter`. **Args:** * ​value (`V`): The value to search for in the `Counter`. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): An `Optional` containing the count if the value was present, otherwise an empty `Optional`. `get(self, value: V, default: Int) -> Int` Get the count of a value from the `Counter`. **Args:** * ​value (`V`): The value to search for in the `Counter`. * ​default ([`Int`](/mojo/stdlib/builtin/int/Int)): Default count to return. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The count if the value was present, otherwise the default. ### `pop` `pop(mut self, value: V) -> Int` Remove a value from the `Counter`. **Args:** * ​value (`V`): The value to remove from the `Counter`. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The count associated with the value, if it was in the `Counter`. **Raises:** "KeyError" if the value was not present in the `Counter`. `pop(mut self, value: V, var default: Int) -> Int` Remove a value from the `Counter`. **Args:** * ​value (`V`): The value to remove from the `Counter`. 
* ​default ([`Int`](/mojo/stdlib/builtin/int/Int)): Optionally provide a default count to return if the value was not found, instead of raising. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The count associated with the value, if it was in the `Counter`. If it wasn't, the provided default value is returned instead. ### `keys` `keys(ref self) -> _DictKeyIter[V, Int, H, origin_of(self_is_origin._data)]` Iterate over the `Counter`'s keys as immutable references. **Returns:** `_DictKeyIter`: An iterator of immutable references to the `Counter` keys. ### `values` `values(ref self) -> _DictValueIter[V, Int, H, origin_of(self_is_origin._data)]` Iterate over the `Counter`'s values as references. **Returns:** [`_DictValueIter`](/mojo/stdlib/collections/dict/_DictValueIter): An iterator of references to the `Counter` values. ### `items` `items(self) -> _DictEntryIter[V, Int, H, origin_of(self._data)]` Iterate over the `Counter`'s entries as immutable references. **Returns:** [`_DictEntryIter`](/mojo/stdlib/collections/dict/_DictEntryIter): An iterator of immutable references to the `Counter` entries. ### `clear` `clear(mut self)` Remove all elements from the `Counter`. ### `popitem` `popitem(mut self) -> CountTuple[V]` Remove and return an arbitrary (key, value) pair from the `Counter`. **Returns:** `CountTuple`: A `CountTuple` containing the key and value of the removed item. **Raises:** "KeyError" if the `Counter` is empty. ### `total` `total(self) -> UInt` Return the total of all counts in the `Counter`. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The total of all counts in the `Counter`. ### `most_common` `most_common(self, n: UInt) -> List[CountTuple[V]]` Return a list of the `n` most common elements and their counts, ordered from most common to least. **Args:** * ​n ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): The number of most common elements to return. 
**Returns:** [`List`](/mojo/stdlib/collections/list/List): A list of the `n` most common elements and their counts. ### `elements` `elements(self) -> List[V]` Return a list of elements, repeating each as many times as its count. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A list of the elements in the `Counter`, with each element repeated according to its count. ### `update` `update(mut self, other: Self)` Update the `Counter`, like `Dict.update()`, but adds counts instead of replacing them. **Args:** * ​other (`Self`): The `Counter` to update this `Counter` with. ### `subtract` `subtract(mut self, other: Self)` Subtract counts. Both inputs and outputs may be zero or negative. **Args:** * ​other (`Self`): The `Counter` to subtract from this `Counter`.
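The arithmetic operators and `total` documented above can be combined as in this hedged sketch; the counts in the comments follow the add/subtract/min/max semantics described for each operator.

```mojo
from collections import Counter

fn main():
    var a = Counter[String]("x", "x", "y")       # x: 2, y: 1
    var b = Counter[String]("x", "y", "y", "z")  # x: 1, y: 2, z: 1

    print((a + b)["x"])  # 3: counts are added
    print((a - b)["x"])  # 1: only positive results are kept
    print((a & b)["y"])  # 1: minimum of the two counts
    print((a | b)["y"])  # 2: maximum of the two counts
    print(a.total())     # 3: sum of all counts in a
```

Note that `a - b` drops any element whose result would be zero or negative; use `subtract` when negative counts should be preserved.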
--- ## counter (Counter)
Defines the `Counter` type. You can import these APIs from the `collections` package. For example: ```mojo from collections import Counter ``` ## Structs * [​`Counter`](/mojo/stdlib/collections/counter/Counter): A container for counting hashable items. * [​`CountTuple`](/mojo/stdlib/collections/counter/CountTuple): A tuple representing a value and its count in a `Counter`.
--- ## Deque
`struct Deque[ElementType: Copyable]` Implements a double-ended queue. It supports pushing and popping from both ends in O(1) time, resizing the underlying storage as needed. ## Parameters * ​ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the `Copyable` trait. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `default_capacity` `comptime default_capacity = 64` The default capacity of the deque: must be a power of 2. ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _DequeIter[ElementType, iterable_origin]` The iterator type for this deque. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self, *, var elements: Optional[List[ElementType]] = None, capacity: Int = Deque[ElementType].default_capacity, min_capacity: Int = Deque[ElementType].default_capacity, maxlen: Int = -1, shrink: Bool = True)` Constructs a deque. **Args:** * ​elements ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The optional list of initial deque elements. * ​capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): The initial capacity of the deque. 
* ​min\_capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): The minimum allowed capacity of the deque when shrinking. * ​maxlen ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum allowed capacity of the deque when growing. * ​shrink ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Should storage be de-allocated when not needed. `__init__(out self, var *values: ElementType, *, __list_literal__: Tuple[] = Tuple[]())` Constructs a deque from the given values. **Args:** * ​\*values (`ElementType`): The values to populate the deque with. * ​**list\_literal** ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tell Mojo to use this method for list literals. `__init__(out self, *, var elements: VariadicListMem[ElementType, origin, is_owned])` Constructs a deque from the given values. **Args:** * ​elements ([`VariadicListMem`](/mojo/stdlib/builtin/variadics/VariadicListMem)): The values to populate the deque with. ### `__copyinit__` `__copyinit__(out self, other: Self)` Creates a deepcopy of the given deque. **Args:** * ​other (`Self`): The deque to copy. ### `__del__` `__del__(deinit self)` Destroys all elements in the deque and free its memory. ### `__bool__` `__bool__(self) -> Bool` Checks whether the deque has any elements or not. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `False` if the deque is empty, `True` if there is at least one element. ### `__getitem__` `__getitem__(ref self, idx: Int) -> ref [self] ElementType` Gets the deque element at the given index. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element. **Returns:** `ref`: A reference to the element at the given index. ### `__eq__` `__eq__[T: Equatable & Copyable, //](self: Deque[T], other: Deque[T]) -> Bool` Checks if two deques are equal. **Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the trait `Equatable`. 
**Args:** * ​other ([`Deque`](/mojo/stdlib/collections/deque/Deque)): The deque to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if the deques are equal, `False` otherwise. ### `__ne__` `__ne__[T: Equatable & Copyable, //](self: Deque[T], other: Deque[T]) -> Bool` Checks if two deques are not equal. **Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the trait `Equatable`. **Args:** * ​other ([`Deque`](/mojo/stdlib/collections/deque/Deque)): The deque to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if the deques are not equal, `False` otherwise. ### `__contains__` `__contains__[T: Equatable & Copyable, //](self: Deque[T], value: T) -> Bool` Verify if a given value is present in the deque. **Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the trait `Equatable`. **Args:** * ​value (`T`): The value to find. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the value is contained in the deque, False otherwise. ### `__add__` `__add__(self, other: Self) -> Self` Concatenates self with other and returns the result as a new deque. **Args:** * ​other (`Self`): Deque whose elements will be appended to the elements of self. **Returns:** `Self`: The newly created deque with the properties of `self`. ### `__mul__` `__mul__(self, n: Int) -> Self` Concatenates `n` deques of `self` and returns a new deque. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The multiplier number. **Returns:** `Self`: The new deque. ### `__iadd__` `__iadd__(mut self, other: Self)` Appends the elements of other deque into self. **Args:** * ​other (`Self`): Deque whose elements will be appended to self. 
### `__imul__` `__imul__(mut self, n: Int)` Concatenates self `n` times in place. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The multiplier number. ### `__iter__` `__iter__(ref self) -> _DequeIter[ElementType, self_is_origin]` Iterates over elements of the deque, returning the references. **Returns:** `_DequeIter`: An iterator of the references to the deque elements. ### `__reversed__` `__reversed__(ref self) -> _DequeIter[ElementType, self_is_origin, False]` Iterate backwards over the deque, returning the references. **Returns:** `_DequeIter`: A reversed iterator of the references to the deque elements. ### `__len__` `__len__(self) -> Int` Gets the number of elements in the deque. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the deque. ### `write_to` `write_to[T: Representable & Copyable](self: Deque[T], mut writer: T)` Writes `my_deque.__str__()` to a `Writer`. **Parameters:** * ​T ([`Representable`](/mojo/stdlib/builtin/repr/Representable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the Deque elements. Must implement the trait `Representable`. **Args:** * ​writer (`T`): The object to write to. ### `__str__` `__str__[T: Representable & Copyable, //](self: Deque[T]) -> String` Returns a string representation of a `Deque`. Note that since we can't condition methods on a trait yet, the way to call this method is a bit special. Here is an example below: ```mojo my_deque = Deque[Int](1, 2, 3) print(my_deque.__str__()) ``` When the compiler supports conditional methods, then a simple `String(my_deque)` will be enough. **Parameters:** * ​T ([`Representable`](/mojo/stdlib/builtin/repr/Representable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the trait `Representable`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the deque. 
### `__repr__` `__repr__[T: Representable & Copyable, //](self: Deque[T]) -> String` Returns a string representation of a `Deque`. Note that since we can't condition methods on a trait yet, the way to call this method is a bit special. Here is an example below: ```mojo my_deque = Deque[Int](1, 2, 3) print(my_deque.__repr__()) ``` When the compiler supports conditional methods, then a simple `repr(my_deque)` will be enough. **Parameters:** * ​T ([`Representable`](/mojo/stdlib/builtin/repr/Representable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the trait `Representable`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the deque. ### `append` `append(mut self, var value: ElementType)` Appends a value to the right side of the deque. **Args:** * ​value (`ElementType`): The value to append. ### `appendleft` `appendleft(mut self, var value: ElementType)` Appends a value to the left side of the deque. **Args:** * ​value (`ElementType`): The value to append. ### `clear` `clear(mut self)` Removes all elements from the deque leaving it with length 0. Resets the underlying storage capacity to `_min_capacity`. ### `count` `count[T: Equatable & Copyable, //](self: Deque[T], value: T) -> Int` Counts the number of occurrences of a `value` in the deque. **Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the trait `Equatable`. **Args:** * ​value (`T`): The value to count. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of occurrences of the value in the deque. ### `extend` `extend(mut self, var values: List[ElementType])` Extends the right side of the deque by consuming elements of the list argument. 
**Args:** * ​values ([`List`](/mojo/stdlib/collections/list/List)): List whose elements will be added at the right side of the deque. ### `extendleft` `extendleft(mut self, var values: List[ElementType])` Extends the left side of the deque by consuming elements from the list argument. Acts as series of left appends resulting in reversed order of elements in the list argument. **Args:** * ​values ([`List`](/mojo/stdlib/collections/list/List)): List whose elements will be added at the left side of the deque. ### `index` `index[T: Equatable & Copyable, //](self: Deque[T], value: T, start: Int = 0, stop: Optional[Int] = None) -> Int` Returns the index of the first occurrence of a `value` in a deque restricted by the range given the `start` and `stop` bounds. **Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the `Equatable` trait. **Args:** * ​value (`T`): The value to search for. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The starting index of the search, treated as a slice index (defaults to 0). * ​stop ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The ending index of the search, treated as a slice index (defaults to None, which means the end of the deque). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The index of the first occurrence of the value in the deque. **Raises:** ValueError: If the value is not found in the deque. ### `insert` `insert(mut self, idx: Int, var value: ElementType)` Inserts the `value` into the deque at position `idx`. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The position to insert the value into. * ​value (`ElementType`): The value to insert. **Raises:** IndexError: If deque is already at its maximum size. ### `remove` `remove[T: Equatable & Copyable, //](mut self: Deque[T], value: T)` Removes the first occurrence of the `value`. 
**Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the deque. Must implement the `Equatable` trait. **Args:** * ​value (`T`): The value to remove. **Raises:** ValueError: If the value is not found in the deque. ### `peek` `peek(self) -> ElementType` Inspect the last (rightmost) element of the deque without removing it. **Returns:** `ElementType`: The last (rightmost) element of the deque. **Raises:** IndexError: If the deque is empty. ### `peekleft` `peekleft(self) -> ElementType` Inspect the first (leftmost) element of the deque without removing it. **Returns:** `ElementType`: The first (leftmost) element of the deque. **Raises:** IndexError: If the deque is empty. ### `pop` `pop(mut self) -> ElementType` Removes and returns the element from the right side of the deque. **Returns:** `ElementType`: The popped value. **Raises:** IndexError: If the deque is empty. ### `popleft` `popleft(mut self) -> ElementType` Removes and returns the element from the left side of the deque. **Returns:** `ElementType`: The popped value. **Raises:** IndexError: If the deque is empty. ### `reverse` `reverse(mut self)` Reverses the elements of the deque in-place. ### `rotate` `rotate(mut self, n: Int = 1)` Rotates the deque by `n` steps. If `n` is positive, rotates to the right. If `n` is negative, rotates to the left. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of steps to rotate the deque (defaults to 1).
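The core operations above compose as in this minimal usage sketch; the element orderings in the comments follow the left/right semantics described for each method.

```mojo
from collections import Deque

fn main() raises:
    var d = Deque[Int](1, 2, 3)
    d.append(4)         # deque now holds [1, 2, 3, 4]
    d.appendleft(0)     # deque now holds [0, 1, 2, 3, 4]

    print(d.popleft())  # 0
    print(d.pop())      # 4
    print(len(d))       # 3: [1, 2, 3] remain

    d.rotate()          # rotate right by one step: [3, 1, 2]
    d.reverse()         # [2, 1, 3]
```

`pop` and `popleft` raise on an empty deque, so the enclosing function is marked `raises`; `rotate` accepts a negative step count to rotate left instead.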
--- ## deque (Deque)
Defines the Deque type. You can import these APIs from the `collections` package. Examples: ```mojo from collections import Deque ``` ## Structs * [​`Deque`](/mojo/stdlib/collections/deque/Deque): Implements a double-ended queue.
--- ## Dict
`struct Dict[K: KeyElement, V: Copyable, H: Hasher = default_hasher]` A container that stores key-value pairs. The `Dict` type is Mojo's primary associative collection, similar to Python's `dict` (dictionary). Unlike a `List`, which stores elements by index, a `Dict` stores values associated with unique keys, which enables fast lookups, insertions, and deletions. You can create a `Dict` in several ways: ```mojo # Empty dictionary var empty_dict = Dict[String, Int]() # Dictionary literal syntax var scores = {"Alice": 95, "Bob": 87, "Charlie": 92} # Pre-allocated capacity (must be power of 2, >= 8) var large_dict = Dict[String, Int](power_of_two_initial_capacity=64) # From separate key and value lists var keys = ["red", "green", "blue"] var values = [255, 128, 64] var colors = Dict[String, Int]() for key, value in zip(keys, values): colors[key] = value ``` Be aware of the following characteristics: * **Type safety**: Both keys and values must be homogeneous types, determined at compile time. This is more restrictive than Python dictionaries but provides better performance: ```mojo var string_to_int = {"count": 42} # Dict[String, Int] var int_to_string = {1: "one"} # Dict[Int, String] var mixed = {"key": 1, 2: "val"} # Error! Keys must be same type ``` However, you can get around this by defining your dictionary key and/or value type as [`Variant`](/mojo/stdlib/utils/variant/Variant). This is a discriminated union type, meaning it can store any number of different types that can vary at runtime. * **Value semantics**: A `Dict` is value semantic by default. Copying a `Dict` creates a deep copy of all key-value pairs. To avoid accidental copies, `Dict` is not implicitly copyable—you must explicitly copy it using the `.copy()` method. 
```mojo var dict1 = {"a": 1, "b": 2} # var dict2 = dict1 # Error: Dict is not implicitly copyable var dict2 = dict1.copy() # Deep copy dict2["c"] = 3 print(dict1.__str__()) # => {"a": 1, "b": 2} print(dict2.__str__()) # => {"a": 1, "b": 2, "c": 3} ``` This is different from Python, where assignment creates a reference to the same dictionary. For more information, read about [value semantics](/mojo/manual/values/value-semantics). * **Iteration uses immutable references**: When iterating over keys, values, or items, you get immutable references unless you specify `ref` or `var`: ```mojo var inventory = {"apples": 10, "bananas": 5} # Default behavior creates immutable (read-only) references for value in inventory.values(): value += 1 # error: expression must be mutable # Using `ref` gets mutable (read-write) references for ref value in inventory.values(): value += 1 # Modify inventory values in-place print(inventory.__str__()) # => {"apples": 11, "bananas": 6} # Using `var` gets an owned copy of the value for var key in inventory.keys(): inventory[key] += 1 # Modify inventory values in-place print(inventory.__str__()) # => {"apples": 12, "bananas": 7} ``` Note that indexing into a `Dict` with a key that's a reference to the key owned by the `Dict` produces a confusing error related to [argument exclusivity](/mojo/manual/values/ownership#argument-exclusivity). Using `var key` in the previous example creates an owned copy of the key, avoiding the error. 
* **KeyError handling**: Directly accessing values with the `[]` operator will raise `KeyError` if the key is not found: ```mojo var phonebook = {"Alice": "555-0101", "Bob": "555-0102"} print(phonebook["Charlie"]) # => KeyError: "Charlie" ``` For safe access, you should instead use `get()`: ```mojo var phonebook = {"Alice": "555-0101", "Bob": "555-0102"} var phone = phonebook.get("Charlie") print(phone.__str__()) if phone else print('phone not found') ``` Examples: ```mojo var phonebook = {"Alice": "555-0101", "Bob": "555-0102"} # Add/update entries phonebook["Charlie"] = "555-0103" # Add new entry phonebook["Alice"] = "555-0199" # Update existing entry # Access directly (unsafe and raises KeyError if key not found) print(phonebook["Alice"]) # => 555-0199 # Access safely var phone = phonebook.get("David") # Returns Optional type print(phone.or_else("phone not found!")) # Access safely with default value phone = phonebook.get("David", "555-0000") print(phone.__str__()) # => '555-0000' # Check for keys if "Bob" in phonebook: print("Found Bob") # Remove (pop) entries print(phonebook.pop("Charlie")) # Remove and return: "555-0103" print(phonebook.pop("Unknown", "N/A")) # Pop with default # Iterate over a dictionary for key in phonebook.keys(): print("Key:", key) for value in phonebook.values(): print("Value:", value) for item in phonebook.items(): print(item.key, "=>", item.value) for var key in phonebook: print(key, "=>", phonebook[key]) # Number of key-value pairs print('len:', len(phonebook)) # => len: 2 # Dictionary operations var backup = phonebook.copy() # Explicit copy phonebook.clear() # Remove all entries # Merge dictionaries var more_numbers = {"David": "555-0104", "Eve": "555-0105"} backup.update(more_numbers) # Merge in-place var combined = backup | more_numbers # Create new merged dict print(combined.__str__()) ``` ## Parameters * ​K ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The type of keys stored in the dictionary. 
* ​V ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of values stored in the dictionary. * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The type of hasher used to hash the keys. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `EMPTY` `comptime EMPTY = -1` Marker for an empty slot in the hash table. ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _DictKeyIter[K, V, H, iterable_origin]` The iterator type for this dictionary. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ### `REMOVED` `comptime REMOVED = -2` Marker for a removed slot in the hash table. ## Methods ### `__init__` `__init__(out self)` Initialize an empty dictionary. `__init__(out self, *, power_of_two_initial_capacity: Int)` Initialize an empty dictionary with a pre-reserved initial capacity. Examples: ```mojo var x = Dict[Int, Int](power_of_two_initial_capacity = 1024) # Insert (2/3 of 1024) entries without reallocation. 
``` **Args:** * ​power\_of\_two\_initial\_capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): Must be a power of two, and at least 8. `__init__(out self, var keys: List[K], var values: List[V], __dict_literal__: Tuple[])` Constructs a dictionary from the given keys and values. **Args:** * ​keys ([`List`](/mojo/stdlib/collections/list/List)): The list of keys to build the dictionary with. * ​values ([`List`](/mojo/stdlib/collections/list/List)): The corresponding values to pair with the keys. * ​**dict\_literal** ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tell Mojo to use this method for dict literals. ### `__copyinit__` `__copyinit__(out self, existing: Self)` Copy an existing dictionary. **Args:** * ​existing (`Self`): The existing dict. ### `__bool__` `__bool__(self) -> Bool` Check if the dictionary is empty or not. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `False` if the dictionary is empty, `True` if there is at least one element. ### `__getitem__` `__getitem__(ref self, key: K) -> ref [origin_of($1._entries._value.value)] V` Retrieve a value out of the dictionary. **Args:** * ​key (`K`): The key to retrieve. **Returns:** `ref`: The value associated with the key, if it's present. **Raises:** "KeyError" if the key isn't present. ### `__setitem__` `__setitem__(mut self, var key: K, var value: V)` Set a value in the dictionary by key. **Args:** * ​key (`K`): The key to associate with the specified value. * ​value (`V`): The data to store in the dictionary. ### `__contains__` `__contains__(self, key: K) -> Bool` Check if a given key is in the dictionary or not. **Args:** * ​key (`K`): The key to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the key exists in the dictionary, False otherwise. ### `__or__` `__or__(self, other: Self) -> Self` Merge self with other and return the result as a new dict. **Args:** * ​other (`Self`): The dictionary to merge with. **Returns:** `Self`: The result of the merge. 
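A brief sketch of the merge operators; note that the comment about key collisions is an assumption of Python-style semantics (the right operand's value wins), which this section does not spell out:

```mojo
def main():
    var a = {"x": 1, "y": 2}
    var b = {"x": 10, "z": 3}
    var merged = a | b   # new dict with keys from both; assuming b's value is kept for "x"
    a |= b               # merge b into a in place
    print(len(merged), len(a))
```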
### `__ior__` `__ior__(mut self, other: Self)` Merge self with other in place. **Args:** * ​other (`Self`): The dictionary to merge with. ### `fromkeys` `static fromkeys(keys: List[K], value: V) -> Self` Create a new dictionary with keys from list and values set to value. **Args:** * ​keys ([`List`](/mojo/stdlib/collections/list/List)): The keys to set. * ​value (`V`): The value to set. **Returns:** `Self`: The new dictionary. `static fromkeys(keys: List[K], value: Optional[V] = None) -> Dict[K, Optional[V], H]` Create a new dictionary with keys from list and values set to value. **Args:** * ​keys ([`List`](/mojo/stdlib/collections/list/List)): The keys to set. * ​value ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The value to set. **Returns:** [`Dict`](/mojo/stdlib/collections/dict/Dict): The new dictionary. ### `__iter__` `__iter__(ref self) -> _DictKeyIter[K, V, H, self_is_origin]` Iterate over the dict's keys as immutable references. **Returns:** `_DictKeyIter`: An iterator of immutable references to the dictionary keys. ### `__reversed__` `__reversed__(ref self) -> _DictKeyIter[K, V, H, self_is_origin, False]` Iterate backwards over the dict keys, returning immutable references. **Returns:** `_DictKeyIter`: A reversed iterator of immutable references to the dict keys. ### `__len__` `__len__(self) -> Int` The number of elements currently stored in the dictionary. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements currently stored in the dictionary. ### `__repr__` `__repr__(self) -> String` Returns a string representation of a `Dict`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the Dict. ### `__str__` `__str__(self) -> String` Returns a string representation of a `Dict`. 
Examples: ```mojo var my_dict = Dict[Int, Float64]() my_dict[1] = 1.1 my_dict[2] = 2.2 dict_as_string = String(my_dict) print(dict_as_string) # prints "{1: 1.1, 2: 2.2}" ``` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the Dict. ### `write_to` `write_to(self, mut writer: T)` Write this `Dict`'s string representation to a `Writer`. **Constraints:** `K` must conform to `Representable`. `V` must conform to `Representable`. **Args:** * ​writer (`T`): The object to write to. ### `find` `find(self, key: K) -> Optional[V]` Find a value in the dictionary by key. **Args:** * ​key (`K`): The key to search for in the dictionary. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): An optional value containing a copy of the value if it was present, otherwise an empty Optional. ### `get` `get(self, key: K) -> Optional[V]` Get a value from the dictionary by key. **Args:** * ​key (`K`): The key to search for in the dictionary. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): An optional value containing a copy of the value if it was present, otherwise an empty Optional. `get(self, key: K, var default: V) -> V` Get a value from the dictionary by key. **Args:** * ​key (`K`): The key to search for in the dictionary. * ​default (`V`): Default value to return. **Returns:** `V`: A copy of the value if it was present, otherwise default. ### `pop` `pop(mut self, key: K, var default: V) -> V` Remove a value from the dictionary by key. **Args:** * ​key (`K`): The key to remove from the dictionary. * ​default (`V`): A default value to return if the key was not found instead of raising. **Returns:** `V`: The value associated with the key, if it was in the dictionary. If it wasn't, return the provided default value instead. `pop(mut self, key: K) -> V` Remove a value from the dictionary by key. **Args:** * ​key (`K`): The key to remove from the dictionary. 
**Returns:** `V`: The value associated with the key, if it was in the dictionary. Raises otherwise. **Raises:** "KeyError" if the key was not present in the dictionary. ### `popitem` `popitem(mut self) -> DictEntry[K, V, H]` Remove and return a (key, value) pair from the dictionary. Notes: Pairs are returned in LIFO order. popitem() is useful to destructively iterate over a dictionary, as often used in set algorithms. If the dictionary is empty, calling popitem() raises a KeyError. **Returns:** `DictEntry`: The last dictionary item. **Raises:** "KeyError" if the dictionary is empty. ### `keys` `keys(ref self) -> _DictKeyIter[K, V, H, self_is_origin]` Iterate over the dict's keys as immutable references. **Returns:** `_DictKeyIter`: An iterator of immutable references to the dictionary keys. ### `values` `values(ref self) -> _DictValueIter[K, V, H, self_is_origin]` Iterate over the dict's values as references. **Returns:** [`_DictValueIter`](/mojo/stdlib/collections/dict/_DictValueIter): An iterator of references to the dictionary values. ### `items` `items(ref self) -> _DictEntryIter[K, V, H, self_is_origin]` Iterate over the dict's entries as immutable references. Examples: ```mojo var my_dict = Dict[String, Int]() my_dict["a"] = 1 my_dict["b"] = 2 for e in my_dict.items(): print(e.key, e.value) ``` Notes: These can't yet be unpacked like Python dict items, but you can access the key and value as attributes. **Returns:** [`_DictEntryIter`](/mojo/stdlib/collections/dict/_DictEntryIter): An iterator of immutable references to the dictionary entries. ### `take_items` `take_items(mut self) -> _TakeDictEntryIter[K, V, H, self]` Iterate over the dict's entries and move them out of the dictionary, effectively draining it. 
Examples: ```mojo var my_dict = Dict[String, Int]() my_dict["a"] = 1 my_dict["b"] = 2 for entry in my_dict.take_items(): print(entry.key, entry.value) print(len(my_dict)) # prints 0 ``` **Returns:** `_TakeDictEntryIter`: An iterator of mutable references to the dictionary entries that moves them out of the dictionary. ### `update` `update(mut self, other: Self, /)` Update the dictionary with the key/value pairs from other, overwriting existing keys. Notes: The argument must be positional only. **Args:** * ​other (`Self`): The dictionary to update from. ### `clear` `clear(mut self)` Remove all elements from the dictionary. ### `setdefault` `setdefault(mut self, key: K, var default: V) -> ref [origin_of(*[0,0]._entries._value.value)] V` Get a value from the dictionary by key, or set it to a default if it doesn't exist. **Args:** * ​key (`K`): The key to search for in the dictionary. * ​default (`V`): The default value to set if the key is not present. **Returns:** `ref`: The value associated with the key, or the default value if it wasn't present.
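`popitem` and `setdefault` have no examples above; a minimal sketch, written as a raising `def` since `popitem` raises `KeyError` on an empty dict:

```mojo
def main():
    var counts = {"a": 1, "b": 2}
    # setdefault inserts the default only when the key is missing and
    # returns a reference to the stored value either way
    _ = counts.setdefault("c", 0)
    # popitem removes pairs in LIFO order
    while len(counts) > 0:
        var entry = counts.popitem()
        print(entry.key, "=>", entry.value)
```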
--- ## DictEntry
`struct DictEntry[K: KeyElement, V: Copyable, H: Hasher]` Store a key-value pair entry inside a dictionary. ## Parameters * ​K ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The key type of the dict. Must be Hashable+Equatable. * ​V ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The value type of the dict. * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The type of the hasher used to hash the key. ## Fields * ​hash (`UInt64`): `key.__hash__()`, stored so hashing isn't re-computed during dict lookup. * ​key (`K`): The unique key for the entry. * ​value (`V`): The value associated with the key. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = V.__copyinit__is_trivial if K.__copyinit__is_trivial else K.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = V.__del__is_trivial if K.__del__is_trivial else K.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = V.__moveinit__is_trivial if K.__moveinit__is_trivial else K.__moveinit__is_trivial` ## Methods ### `__init__` `__init__(out self, var key: K, var value: V)` Create an entry from a key and value, computing the hash. **Args:** * ​key (`K`): The key of the entry. * ​value (`V`): The value of the entry. ### `reap_value` `reap_value(deinit self) -> V` Take the value from an owned entry. **Returns:** `V`: The value of the entry.
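`DictEntry` is the element type yielded by `Dict.items()`; a minimal sketch of reading its fields through that path:

```mojo
def main():
    var d = {"answer": 42}
    for e in d.items():
        # each entry carries the precomputed hash alongside the key and value
        print(e.key, e.value, e.hash)
```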
--- ## OwnedKwargsDict
`struct OwnedKwargsDict[V: Copyable]` Container used to pass owned variadic keyword arguments to functions. This type mimics the interface of a dictionary with `String` keys, and should be usable more-or-less like a dictionary. Notably, however, this type should not be instantiated directly by users. ## Parameters * ​V ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The value type of the dictionary. Currently must be Copyable. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _DictKeyIter[OwnedKwargsDict[V].key_type, V, default_comp_time_hasher, iterable_origin]` The iterator type for this dictionary. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ### `key_type` `comptime key_type = String` The key type for this dictionary (always String). ## Methods ### `__init__` `__init__(out self)` Initialize an empty keyword dictionary. ### `__getitem__` `__getitem__(ref self, key: String) -> ref [origin_of($1._dict._entries._value.value)] V` Retrieve a value out of the keyword dictionary. **Args:** * ​key ([`String`](/mojo/stdlib/collections/string/string/String)): The key to retrieve. 
**Returns:** `ref`: The value associated with the key, if it's present. **Raises:** "KeyError" if the key isn't present. ### `__setitem__` `__setitem__(mut self, key: String, var value: V)` Set a value in the keyword dictionary by key. **Args:** * ​key ([`String`](/mojo/stdlib/collections/string/string/String)): The key to associate with the specified value. * ​value (`V`): The data to store in the dictionary. ### `__contains__` `__contains__(self, key: String) -> Bool` Check if a given key is in the keyword dictionary or not. **Args:** * ​key ([`String`](/mojo/stdlib/collections/string/string/String)): The key to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the key exists in the keyword dictionary, False otherwise. ### `__len__` `__len__(self) -> Int` The number of elements currently stored in the keyword dictionary. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements currently stored in the keyword dictionary. ### `find` `find(self, key: String) -> Optional[V]` Find a value in the keyword dictionary by key. **Args:** * ​key ([`String`](/mojo/stdlib/collections/string/string/String)): The key to search for in the dictionary. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): An optional value containing a copy of the value if it was present, otherwise an empty Optional. ### `pop` `pop(mut self, key: String, var default: V) -> V` Remove a value from the dictionary by key. **Args:** * ​key ([`String`](/mojo/stdlib/collections/string/string/String)): The key to remove from the dictionary. * ​default (`V`): A default value to return if the key was not found instead of raising. **Returns:** `V`: The value associated with the key, if it was in the dictionary. If it wasn't, return the provided default value instead. `pop(mut self, key: String) -> V` Remove a value from the dictionary by key. 
**Args:** * ​key ([`String`](/mojo/stdlib/collections/string/string/String)): The key to remove from the dictionary. **Returns:** `V`: The value associated with the key, if it was in the dictionary. Raises otherwise. **Raises:** "KeyError" if the key was not present in the dictionary. ### `__iter__` `__iter__(ref self) -> _DictKeyIter[OwnedKwargsDict[V].key_type, V, default_comp_time_hasher, self_is_origin]` Iterate over the keyword dict's keys as immutable references. **Returns:** `_DictKeyIter`: An iterator of immutable references to the dictionary keys. ### `keys` `keys(ref self) -> _DictKeyIter[OwnedKwargsDict[V].key_type, V, default_comp_time_hasher, origin_of(self_is_origin._dict)]` Iterate over the keyword dict's keys as immutable references. **Returns:** `_DictKeyIter`: An iterator of immutable references to the dictionary keys. ### `values` `values(ref self) -> _DictValueIter[OwnedKwargsDict[V].key_type, V, default_comp_time_hasher, origin_of(self_is_origin._dict)]` Iterate over the keyword dict's values as references. **Returns:** [`_DictValueIter`](/mojo/stdlib/collections/dict/_DictValueIter): An iterator of references to the dictionary values. ### `items` `items(ref self) -> _DictEntryIter[OwnedKwargsDict[V].key_type, V, default_comp_time_hasher, origin_of(self_is_origin._dict)]` Iterate over the keyword dictionary's entries as immutable references. Examples: ```mojo var my_dict = Dict[String, Int]() my_dict["a"] = 1 my_dict["b"] = 2 for e in my_dict.items(): print(e.key, e.value) ``` Notes: These can't yet be unpacked like Python dict items, but you can access the key and value as attributes. **Returns:** [`_DictEntryIter`](/mojo/stdlib/collections/dict/_DictEntryIter): An iterator of immutable references to the dictionary entries.
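Since `OwnedKwargsDict` should not be constructed directly, the typical way to encounter one is as the receiver of variadic keyword arguments; a minimal sketch, marked `raises` because `[]` can raise `KeyError`:

```mojo
fn show_kwargs(**kwargs: Int) raises:
    # kwargs is an OwnedKwargsDict[Int] with String keys
    for key in kwargs.keys():
        print(key, "=", kwargs[key])

def main():
    show_kwargs(a=1, b=2)
```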
--- ## dict (Dict)
Defines `Dict`, a collection that stores key-value pairs. Dict provides O(1) amortized average-time complexity for insert, lookup, and removal of dictionary elements. Its implementation closely mirrors Python's `dict`: * Performance and size are heavily optimized for small dictionaries, but can scale to large dictionaries. * Insertion order is implicitly preserved. Iteration over keys, values, and items has a deterministic order based on insertion. * For more information on the Mojo `Dict` type, see the [Mojo `Dict` manual](/mojo/manual/types/#dict). To learn more about using Python dictionaries from Mojo, see [Python types in Mojo](/mojo/manual/python/types/#python-types-in-mojo). Key elements must implement the `KeyElement` trait composition, which includes `Hashable`, `Equatable`, and `Copyable`. The `Copyable` requirement will eventually be removed. Value elements must be `Copyable`. As with `KeyElement`, the `Copyable` requirement for value elements will eventually be removed. See the `Dict` docs for more details. ## `comptime` values ### `KeyElement` `comptime KeyElement = Copyable & Hashable & Equatable` A trait composition for types which implement all requirements of dictionary keys. Dict keys must minimally be `Copyable`, `Hashable`, and `Equatable`. ## Structs * [​`Dict`](/mojo/stdlib/collections/dict/Dict): A container that stores key-value pairs. * [​`DictEntry`](/mojo/stdlib/collections/dict/DictEntry): Store a key-value pair entry inside a dictionary. * [​`OwnedKwargsDict`](/mojo/stdlib/collections/dict/OwnedKwargsDict): Container used to pass owned variadic keyword arguments to functions.
--- ## collections
Implements the collections package. ## Packages * [​`string`](/mojo/stdlib/collections/string/): The string package provides comprehensive Unicode string handling functionality for Mojo. ## Modules * [​`bitset`](/mojo/stdlib/collections/bitset/): Provides a compact, grow-only set of non-negative integers. * [​`counter`](/mojo/stdlib/collections/counter/): Defines the `Counter` type. * [​`deque`](/mojo/stdlib/collections/deque/): Defines the Deque type. * [​`dict`](/mojo/stdlib/collections/dict/): Defines `Dict`, a collection that stores key-value pairs. * [​`inline_array`](/mojo/stdlib/collections/inline_array/): Provides a fixed-size array implementation with compile-time size checking. * [​`interval`](/mojo/stdlib/collections/interval/): A self-balancing interval tree is a specialized binary search tree designed to efficiently store and query intervals. * [​`linked_list`](/mojo/stdlib/collections/linked_list/): Defines the `LinkedList` type. * [​`list`](/mojo/stdlib/collections/list/): Defines the List type. * [​`optional`](/mojo/stdlib/collections/optional/): Defines Optional, a type modeling a value which may or may not be present. * [​`set`](/mojo/stdlib/collections/set/): Implements the Set datatype.
--- ## InlineArray
`struct InlineArray[ElementType: Copyable, size: Int]` A fixed-size sequence of homogeneous elements where size is a constant expression. InlineArray provides a fixed-size array implementation with compile-time size checking. The array size is determined at compile time and cannot be changed. Elements must implement the `Copyable` trait. Examples: ```mojo # Create array of 3 integers var arr = InlineArray[Int, 3](1, 2, 3) # Create array filled with value var filled = InlineArray[Int, 5](fill=42) # Access elements print(arr[0]) # Prints 1 ``` ## Parameters * ​ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the array. Must implement `Copyable` trait. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the array. Must be a positive integer constant. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = ElementType.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = ElementType.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = ElementType.__moveinit__is_trivial` ### `device_type` `comptime device_type = InlineArray[ElementType, size]` The device-side type for this array. ### `type` `comptime type = __mlir_type.`!pop.array<#lit.struct.extract<:!lit.struct<@stdlib::@builtin::@int::@Int> size, "\_mlir\_value">, :trait<@stdlib::@builtin::@value::@Copyable> ElementType>\`\` The underlying MLIR array type. 
## Methods ### `__init__` `__init__(out self)` This constructor will always cause a compile time error if used. It is used to steer users away from uninitialized memory. `__init__(out self, *, uninitialized: Bool)` Create an InlineArray with uninitialized memory. Examples: ```mojo var uninitialized_array = InlineArray[Int, 10](uninitialized=True) ``` Notes: This constructor is unsafe and should be used with caution. The array elements will be uninitialized and accessing them before initialization is undefined behavior. **Args:** * ​uninitialized ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): A boolean to indicate if the array should be initialized. Always set to `True` (it's not actually used inside the constructor). `__init__(out self, *, var unsafe_assume_initialized: InlineArray[UnsafeMaybeUninitialized[ElementType], size])` Constructs an `InlineArray` from an `InlineArray` of `UnsafeMaybeUninitialized`. Warning: This is an unsafe constructor. Only use it if you are certain all elements are properly initialized. Notes: This constructor assumes all elements in the input array are initialized. Using uninitialized elements results in undefined behavior, even for types that are valid for any bit pattern (e.g. `Int` or `Float`). **Args:** * ​unsafe\_assume\_initialized ([`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray)): The array of `UnsafeMaybeUninitialized` elements. All elements must be initialized. `__init__[batch_size: Int = 64](out self, *, fill: ElementType)` Constructs an array where each element is initialized to the supplied value. Examples: ```mojo var filled = InlineArray[Int, 5](fill=42) # [42, 42, 42, 42, 42] # For large arrays, consider adjusting batch_size to balance # compile time and runtime performance: var large = InlineArray[Int, 10000].__init__[batch_size=32](fill=0) ``` Notes: * Full unrolling with large arrays (>2k elements) can cause significant compiler slowdowns. 
* Using batch\_size=64 balances AVX512 efficiency and instruction cache usage. * For very large arrays, using smaller batch sizes (e.g., 32 or 16) can further improve compilation speed while still maintaining good runtime performance. **Parameters:** * ​batch\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to unroll for filling the array. Default is 64, which optimizes for AVX512 operations on modern CPUs. For large arrays (>2k elements), this batched approach significantly improves compile times compared to full unrolling while maintaining good runtime performance. **Args:** * ​fill (`ElementType`): The element value to fill each index with. `__init__(out self, var *elems: ElementType, *, __list_literal__: Tuple[] = Tuple[]())` Constructs an array from a variadic list of elements. Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) # [1, 2, 3] ``` **Args:** * ​\*elems (`ElementType`): The elements to initialize the array with. Must match the array size. * ​**list\_literal** ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Specifies that this constructor can be used for list literals. `__init__[origin: MutOrigin, //](out self, *, var storage: VariadicListMem[ElementType, origin, True])` Construct an array from a low-level internal representation. **Parameters:** * ​origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): The origin of the storage being passed in. **Args:** * ​storage ([`VariadicListMem`](/mojo/stdlib/builtin/variadics/VariadicListMem)): The variadic list storage to construct from. Must match array size. ### `__copyinit__` `__copyinit__(out self, other: Self)` Copy constructs the array from another array. Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) var copy = arr.copy() # Creates new array [1, 2, 3] ``` **Args:** * ​other (`Self`): The array to copy from. ### `__moveinit__` `__moveinit__(out self, deinit other: Self)` Move constructs the array from another array. 
Notes: Moves the elements from the source array into this array. **Args:** * ​other (`Self`): The array to move from. ### `__del__` `__del__(deinit self)` Deallocates the array and destroys its elements. ### `__getitem__` `__getitem__[I: Indexer](ref self, idx: I) -> ref [self] ElementType` Gets a reference to the element at the given index. Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) print(arr[0]) # Prints 1 - first element print(arr[1]) # Prints 2 - second element print(arr[-1]) # Prints 3 - last element print(arr[-2]) # Prints 2 - second to last element ``` Notes: This method provides array-style indexing access to elements in the InlineArray. It supports both positive indices starting from 0 and negative indices counting backwards from the end of the array. The index is bounds-checked at runtime. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type parameter representing the index type, must implement Indexer trait. **Args:** * ​idx (`I`): The index to access. Can be positive (0 to len-1) or negative (-len to -1). **Returns:** `ref`: A reference to the element at the specified index. `__getitem__[I: Indexer, //, idx: I](ref self) -> ref [self] ElementType` Gets a reference to the element at the given index with compile-time bounds checking. Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) print(arr[0]) # Prints 1 - first element print(arr[-1]) # Prints 3 - last element ``` Notes: This overload provides array-style indexing with compile-time bounds checking. The index must be a compile-time constant value. It supports both positive indices starting from 0 and negative indices counting backwards from the end of the array. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type parameter representing the index type, must implement Indexer trait. * ​idx (`I`): The compile-time constant index to access. Can be positive (0 to len-1) or negative (-len to -1). 
**Returns:** `ref`: A reference to the element at the specified index. ### `__contains__` `__contains__[T: Equatable & Copyable, //](self: InlineArray[T, size], value: T) -> Bool` Tests if a value is present in the array using the `in` operator. Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) print(3 in arr) # Prints True - value exists print(4 in arr) # Prints False - value not found ``` Notes: This method enables using the `in` operator to check if a value exists in the array. It performs a linear search comparing each element for equality with the given value. The element type must implement the `Equatable` and `Copyable` traits to support equality comparison. **Parameters:** * ​T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The element type, must implement both `Equatable` and `Copyable`. **Args:** * ​value (`T`): The value to search for. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the value is found in any position in the array, False otherwise. ### `get_type_name` `static get_type_name() -> String` Gets the name of the host type (the one implementing this trait). **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The host type's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The device type's name. ### `__len__` `__len__(self) -> Int` Returns the length of the array. Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) print(len(arr)) # Prints 3 ``` Notes: The length is a compile-time constant value determined by the size parameter used when creating the array. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The size of the array as an Int. ### `unsafe_get` `unsafe_get[I: Indexer](ref self, idx: I) -> ref [self] ElementType` Gets a reference to an element without bounds checking. 
Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) print(arr.unsafe_get(0)) # Prints 1 ``` Warning: This is an unsafe method. No bounds checking is performed. Using an invalid index will cause undefined behavior. Negative indices are not supported. Notes: This is an unsafe method that skips bounds checking for performance. Users should prefer `__getitem__` instead for safety. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): A type parameter representing the index type, must implement Indexer trait. **Args:** * ​idx (`I`): The index of the element to get. Must be non-negative and in bounds. Using an invalid index will cause undefined behavior. **Returns:** `ref`: A reference to the element at the given index. ### `unsafe_ptr` `unsafe_ptr[origin: Origin[mut], address_space: AddressSpace, //](ref [origin, $2] self) -> UnsafePointer[ElementType, origin, address_space=address_space]` Gets an unsafe pointer to the underlying array storage. Examples: ```mojo var arr = InlineArray[Int, 3](1, 2, 3) var ptr = arr.unsafe_ptr() print(ptr[0]) # Prints 1 ``` Warning: This is an unsafe method. The returned pointer: * Becomes invalid if the array is moved * Must not be used to access memory outside array bounds * Must be refreshed after any operation that could move the array Notes: Returns a raw pointer to the array's memory that can be used for direct memory access. The pointer inherits mutability from the array reference. **Parameters:** * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the reference to self. * ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The address space of the array. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): An `UnsafePointer` to the underlying array storage. The pointer's mutability matches that of the array reference.
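The `unsafe_assume_initialized` constructor documented above pairs with element-wise deferred initialization. A hedged sketch, assuming `UnsafeMaybeUninitialized.write` initializes a slot in place:

```mojo
# Sketch of deferred initialization (assumes UnsafeMaybeUninitialized.write).
var staging = InlineArray[UnsafeMaybeUninitialized[Int], 3](uninitialized=True)
for i in range(3):
    staging[i].write(i * i)  # every slot must be initialized before the cast
# Unsafe: only valid because all three slots were written above.
var arr = InlineArray[Int, 3](unsafe_assume_initialized=staging^)
print(arr[2])  # 4
```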
--- ## inline_array
Provides a fixed-size array implementation with compile-time size checking. The `InlineArray` type represents a fixed-size sequence of homogeneous elements where the size is determined at compile time. It provides efficient memory layout and bounds checking while maintaining type safety. The `InlineArray` type is part of the `prelude` module and therefore does not need to be imported in order to use it. Examples: ```mojo # Create an array of 3 integers var arr = InlineArray[Int, 3](1, 2, 3) # Access elements print(arr[0]) # Prints 1 # Fill with a value var filled = InlineArray[Int, 5](fill=42) ``` ## Structs * [​`InlineArray`](/mojo/stdlib/collections/inline_array/InlineArray): A fixed-size sequence of homogeneous elements where size is a constant expression.
--- ## Interval
`struct Interval[T: IntervalElement]` A half-open interval \[start, end) that represents a range of values. The interval includes the start value but excludes the end value. ## Parameters * ​T ([`IntervalElement`](/mojo/stdlib/collections/interval/IntervalElement)): The type of the interval bounds. ## Fields * ​start (`T`): The inclusive start of the interval. * ​end (`T`): The exclusive end of the interval. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = T.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = T.__moveinit__is_trivial` ## Methods ### `__init__` `__init__(out self, start: T, end: T)` Initialize an interval with start and end values. **Args:** * ​start (`T`): The starting value of the interval. * ​end (`T`): The ending value of the interval. Must be greater than or equal to start. `__init__(out self, interval: Tuple[T, T], /)` Initialize an interval with a tuple of start and end values. **Args:** * ​interval ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): A tuple containing the start and end values. ### `__copyinit__` `__copyinit__(out self, existing: Self, /)` Create a new instance of the interval by copying the values from an existing one. 
**Args:** * ​existing (`Self`): The interval to copy values from. ### `__bool__` `__bool__(self) -> Bool` Returns whether this interval is empty. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the interval is not empty (start < end), False otherwise. ### `__lt__` `__lt__(self, other: Self) -> Bool` Returns whether this interval is less than another interval. **Args:** * ​other (`Self`): The interval to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this interval's start is less than the other interval's start. ### `__le__` `__le__(self, other: Self) -> Bool` Returns whether this interval is less than or equal to another interval. **Args:** * ​other (`Self`): The interval to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this interval's start is less than or equal to the other interval's start. ### `__eq__` `__eq__(self, other: Self) -> Bool` Returns whether this interval equals another interval. **Args:** * ​other (`Self`): The interval to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both intervals have the same start and end values. ### `__gt__` `__gt__(self, other: Self) -> Bool` Returns whether this interval is greater than another interval. **Args:** * ​other (`Self`): The interval to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this interval's end is greater than the other interval's end. ### `__ge__` `__ge__(self, other: Self) -> Bool` Returns whether this interval is greater than or equal to another interval. **Args:** * ​other (`Self`): The interval to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this interval's end is greater than or equal to the other interval's end. ### `__contains__` `__contains__(self, other: T) -> Bool` Returns whether a value is contained within this interval. **Args:** * ​other (`T`): The value to check. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the value is within the interval bounds, False otherwise. `__contains__(self, other: Self) -> Bool` Returns whether another interval is fully contained within this interval. **Args:** * ​other (`Self`): The interval to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the other interval is fully contained within this interval, False otherwise. ### `overlaps` `overlaps(self, other: Self) -> Bool` Returns whether this interval overlaps with another interval. **Args:** * ​other (`Self`): The interval to check for overlap with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the intervals overlap, False otherwise. ### `union` `union(self, other: Self) -> Self` Returns the union of this interval and another interval. **Args:** * ​other (`Self`): The interval to union with. **Returns:** `Self`: The union of this interval and the other interval. ### `intersection` `intersection(self, other: Self) -> Self` Returns the intersection of this interval and another interval. **Args:** * ​other (`Self`): The interval to intersect with. **Returns:** `Self`: The intersection of this interval and the other interval. ### `__len__` `__len__(self) -> Int` Returns the length of this interval. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The difference between end and start values as an integer. ### `write_to` `write_to(self, mut writer: T)` Writes this interval to a writer in the format '(start, end)'. **Args:** * ​writer (`T`): The writer to write the interval to. ### `__str__` `__str__(self) -> String` Returns a string representation of this interval. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string in the format '(start, end)' representing this interval. ### `__repr__` `__repr__(self) -> String` Returns a string representation of this interval suitable for debugging. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string in the format '(start, end)' representing this interval.
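Putting the methods above together, assuming `Int` satisfies `IntervalElement`:

```mojo
var a = Interval[Int](0, 10)    # half-open: covers 0 through 9
var b = Interval[Int]((5, 15))  # tuple form of the constructor
print(0 in a)         # True: start is inclusive
print(10 in a)        # False: end is exclusive
print(a.overlaps(b))  # True: [0, 10) and [5, 15) share [5, 10)
print(len(a))         # 10
print(String(a.union(b)))         # (0, 15)
print(String(a.intersection(b)))  # (5, 10)
```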
--- ## IntervalElement
A trait composition of the `Copyable`, `Writable`, `Intable`, and `Comparable` traits that additionally supports subtraction. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. 
## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. **Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `__lt__` `__lt__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is less than `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is less than `rhs`. ### `__eq__` `__eq__(self: _Self, other: _Self) -> Bool` Define whether two instances of the object are equal to each other. **Args:** * ​other (`_Self`): Another instance of the same type. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the instances are equal according to the type's definition of equality, False otherwise. ### `__sub__` `__sub__(self: _Self, rhs: _Self) -> _Self` Subtracts rhs from self, must be implemented in concrete types. **Args:** * ​rhs (`_Self`): The value to subtract from self. **Returns:** `_Self`: The result of subtracting rhs from self. ### `__int__` `__int__(self: _Self) -> Int` Get the integral representation of the value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integral representation of the value. ### `write_to` `write_to(self: _Self, mut writer: T)` Formats the string representation of this type to the provided Writer. **Args:** * ​writer (`T`): The type conforming to `Writable`. ## Provided methods ### `__le__` `__le__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is less than or equal to `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is less than or equal to `rhs`. 
### `__ne__` `__ne__(self: _Self, other: _Self) -> Bool` Define whether two instances of the object are not equal to each other. **Args:** * ​other (`_Self`): Another instance of the same type. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the instances are not equal according to the type's definition of equality, False otherwise. ### `__gt__` `__gt__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is greater than `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is greater than `rhs`. ### `__ge__` `__ge__(self: _Self, rhs: _Self) -> Bool` Define whether `self` is greater than or equal to `rhs`. **Args:** * ​rhs (`_Self`): The value to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if `self` is greater than or equal to `rhs`. ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
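A minimal conforming type might look like the following sketch. The `@fieldwise_init` decorator and the exact signatures are assumptions based on the required methods listed above, not stdlib code:

```mojo
@fieldwise_init  # assumption: synthesizes the fieldwise constructor
struct Tick(IntervalElement):
    var v: Int

    fn __lt__(self, rhs: Self) -> Bool:
        return self.v < rhs.v

    fn __eq__(self, other: Self) -> Bool:
        return self.v == other.v

    fn __sub__(self, rhs: Self) -> Self:
        return Self(self.v - rhs.v)  # needed for interval length arithmetic

    fn __int__(self) -> Int:
        return self.v

    fn write_to[W: Writer](self, mut writer: W):
        writer.write(self.v)
```

The provided methods (`__le__`, `__ne__`, `__gt__`, `__ge__`, `copy`) then come for free from the trait.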
--- ## IntervalTree
`struct IntervalTree[T: IntervalElement, U: Copyable & Stringable & Comparable]` An interval tree data structure for efficient range queries. ## Parameters * ​T ([`IntervalElement`](/mojo/stdlib/collections/interval/IntervalElement)): The type of the interval bounds, must support subtraction, integer conversion, string conversion, comparison and collection operations. * ​U ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Stringable`](/mojo/stdlib/builtin/str/Stringable) & [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable)): The type of the associated data, must support string conversion and collection operations. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ## Methods ### `__init__` `__init__(out self)` Initializes an empty IntervalTree. ### `__del__` `__del__(deinit self)` Destructor that frees the interval tree's memory. ### `insert` `insert(mut self, interval: Tuple[T, T], data: U)` Insert a new interval into the tree using a tuple representation. **Args:** * ​interval ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): A tuple containing the start and end values of the interval. * ​data (`U`): The data value to associate with this interval. `insert(mut self, interval: Interval[T], data: U)` Insert a new interval into the tree. This method inserts a new interval and its associated data into the interval tree. It maintains the binary search tree property based on interval start times and updates the tree structure to preserve red-black tree properties. **Args:** * ​interval ([`Interval`](/mojo/stdlib/collections/interval/Interval)): The interval to insert into the tree. * ​data (`U`): The data value to associate with this interval. 
### `__str__` `__str__(self) -> String` Returns a string representation of the interval tree. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the interval tree. ### `__repr__` `__repr__(self) -> String` Returns a string representation of the interval tree suitable for debugging. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the interval tree. ### `write_to` `write_to(self, mut writer: T)` Writes the interval tree to a writer. **Args:** * ​writer (`T`): The writer to write the interval tree to. ### `depth` `depth(self) -> Int` Returns the depth of the interval tree. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The depth of the interval tree. ### `transplant` `transplant(mut self, mut u: UnsafePointer[_IntervalNode[T, U], origin_of()], mut v: UnsafePointer[_IntervalNode[T, U], origin_of()])` Transplants the subtree rooted at node u with the subtree rooted at node v. **Args:** * ​u ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): The node to transplant. * ​v ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): The node to transplant to. ### `search` `search(self, interval: Tuple[T, T]) -> List[U]` Searches for intervals overlapping with the given tuple. **Args:** * ​interval ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): The interval tuple (start, end). **Returns:** [`List`](/mojo/stdlib/collections/list/List): A list of data associated with overlapping intervals. **Raises:** If the operation fails. `search(self, interval: Interval[T]) -> List[U]` Searches for intervals overlapping with the given interval. **Args:** * ​interval ([`Interval`](/mojo/stdlib/collections/interval/Interval)): The interval to search. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A list of data associated with overlapping intervals. **Raises:** If the operation fails.
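A typical insert-and-query round trip; since `search` raises, the sketch runs inside a raising function:

```mojo
fn demo() raises:
    var tree = IntervalTree[Int, String]()
    tree.insert((5, 10), "a")
    tree.insert((15, 25), "b")
    tree.insert((1, 12), "c")
    # The query [8, 16) overlaps [5, 10), [1, 12), and [15, 25).
    var hits = tree.search((8, 16))
    for hit in hits:
        print(hit)  # each overlapping interval's payload
```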
--- ## interval (Interval)
A self-balancing interval tree is a specialized binary search tree designed to efficiently store and query intervals. It maintains intervals sorted by their low endpoints and augments each node with a `max_high` attribute, representing the maximum high endpoint in its subtree. This `max_high` value enables efficient overlap searching by pruning the search space. Self-balancing mechanisms, such as Red-Black or AVL trees, ensure logarithmic time complexity for operations. Key Features: * Stores intervals (low, high). * Nodes ordered by `low` endpoints. * `max_high` attribute at each node for efficient overlap search. * Self-balancing (e.g., using Red-Black tree logic) for O(log n) operations. Operations: * Insertion: O(log n) - Adds a new interval, maintaining balance and updating `max_high`. * Overlap Search: O(log n) - Finds intervals overlapping a query interval using `max_high` for pruning. * Deletion: O(log n) - Removes an interval, maintaining balance and updating `max_high`. Space Complexity: O(n), where n is the number of intervals. Use Cases: * Calendar scheduling * Computational geometry * Genomics * Database indexing * Resource allocation In essence, this data structure provides a fast and efficient way to manage and query interval data, particularly for finding overlaps. ## Structs * [​`Interval`](/mojo/stdlib/collections/interval/Interval): A half-open interval \[start, end) that represents a range of values. * [​`IntervalTree`](/mojo/stdlib/collections/interval/IntervalTree): An interval tree data structure for efficient range queries. ## Traits * [​`IntervalElement`](/mojo/stdlib/collections/interval/IntervalElement): A trait composition of the `Copyable`, `Writable`, `Intable`, and `Comparable` traits that additionally supports subtraction.
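The `max_high` pruning described above can be sketched in pseudocode. The node fields here are illustrative only; the stdlib's internal node type is private:

```mojo
# Pseudocode only — `node.left`, `node.max_high`, etc. are hypothetical names.
#
# fn overlap_search(node, query, mut out):
#     if node is empty:
#         return
#     # Prune: no interval in this subtree ends after query.start,
#     # so nothing here can overlap a half-open query.
#     if node.max_high <= query.start:
#         return
#     overlap_search(node.left, query, out)
#     if node.interval.overlaps(query):
#         out.append(node.data)
#     # The right subtree holds larger starts; skip it once starts
#     # reach query.end.
#     if node.interval.start < query.end:
#         overlap_search(node.right, query, out)
```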
--- ## LinkedList
`struct LinkedList[ElementType: Copyable]` A doubly-linked list implementation. A doubly-linked list is a data structure where each element points to both the next and previous elements, allowing for efficient insertion and deletion at any position. ## Parameters * ​ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of elements stored in the list. Must implement the `Copyable` trait. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _LinkedListIter[ElementType, iterable_origin]` The iterator type for this linked list. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self)` Initialize an empty linked list. Notes: Time Complexity: O(1). `__init__(out self, var *elements: ElementType, *, __list_literal__: Tuple[] = Tuple[]())` Initialize a linked list with the given elements. Notes: Time Complexity: O(n) in len(elements). **Args:** * ​\*elements (`ElementType`): Variable number of elements to initialize the list with. * ​\_\_list\_literal\_\_ ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tell Mojo to use this method for list literals. 
`__init__(out self, *, var elements: VariadicListMem[ElementType, origin, is_owned])` Construct a list from a `VariadicListMem`. Notes: Time Complexity: O(n) in len(elements). **Args:** * ​elements ([`VariadicListMem`](/mojo/stdlib/builtin/variadics/VariadicListMem)): The elements to add to the list. ### `__copyinit__` `__copyinit__(out self, other: Self)` Initialize this list as a copy of another list. Notes: Time Complexity: O(n) in len(elements). **Args:** * ​other (`Self`): The list to copy from. ### `__del__` `__del__(deinit self)` Clean up the list by freeing all nodes. Notes: Time Complexity: O(n) in len(self). ### `__bool__` `__bool__(self) -> Bool` Check if the list is non-empty. Notes: Time Complexity: O(1). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the list has elements, False otherwise. ### `__getitem__` `__getitem__[I: Indexer](ref self, idx: I) -> ref [self] ElementType` Get the element at the specified index. Notes: Time Complexity: O(n) in len(self). **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of index to use. **Args:** * ​idx (`I`): The index of the element to get. **Returns:** `ref`: The element at the specified index. ### `__eq__` `__eq__[_ElementType: Equatable & Copyable, //](self: LinkedList[_ElementType], other: LinkedList[_ElementType]) -> Bool` Checks if the two lists are equal. Notes: Time Complexity: O(n) in min(len(self), len(other)) compares. **Parameters:** * ​\_ElementType ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The list element type, used to conditionally enable the function. **Args:** * ​other ([`LinkedList`](/mojo/stdlib/collections/linked_list/LinkedList)): The list to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Whether the lists are equal. 
### `__ne__` `__ne__[_ElementType: Equatable & Copyable, //](self: LinkedList[_ElementType], other: LinkedList[_ElementType]) -> Bool` Checks if the two lists are not equal. Notes: Time Complexity: O(n) in min(len(self), len(other)) compares. **Parameters:** * ​\_ElementType ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The list element type, used to conditionally enable the function. **Args:** * ​other ([`LinkedList`](/mojo/stdlib/collections/linked_list/LinkedList)): The list to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Whether the lists are not equal. ### `__contains__` `__contains__[_ElementType: Equatable & Copyable, //](self: LinkedList[_ElementType], value: _ElementType) -> Bool` Checks if the list contains `value`. Notes: Time Complexity: O(n) in len(self) compares. **Parameters:** * ​\_ElementType ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The list element type, used to conditionally enable the function. **Args:** * ​value (`_ElementType`): The value to search for in the list. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Whether the list contains `value`. ### `append` `append(mut self, var value: ElementType)` Add an element to the end of the list. Notes: Time Complexity: O(1). **Args:** * ​value (`ElementType`): The value to append. ### `prepend` `prepend(mut self, var value: ElementType)` Add an element to the beginning of the list. Notes: Time Complexity: O(1). **Args:** * ​value (`ElementType`): The value to prepend. ### `reverse` `reverse(mut self)` Reverse the order of elements in the list. Notes: Time Complexity: O(n) in len(self). ### `pop` `pop(mut self) -> ElementType` Remove and return the last element of the list. Notes: Time Complexity: O(1). **Returns:** `ElementType`: The last element in the list. **Raises:** If the operation fails. 
`pop[I: Indexer, //](mut self, var i: I) -> ElementType` Remove the ith element of the list, counting from the tail if given a negative index. Notes: Time Complexity: O(n) in len(self). **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of index to use. **Args:** * ​i (`I`): The index of the element to get. **Returns:** `ElementType`: Ownership of the indicated element. **Raises:** If the operation fails. ### `maybe_pop` `maybe_pop(mut self) -> Optional[ElementType]` Removes the tail of the list and returns it, if it exists. Notes: Time Complexity: O(1). **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): The tail of the list, if it was present. `maybe_pop[I: Indexer, //](mut self, var i: I) -> Optional[ElementType]` Remove the ith element of the list, counting from the tail if given a negative index. Notes: Time Complexity: O(n) in len(self). **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of index to use. **Args:** * ​i (`I`): The index of the element to get. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): The element, if it was found. ### `clear` `clear(mut self)` Removes all elements from the list. Notes: Time Complexity: O(n) in len(self). ### `insert` `insert[I: Indexer](mut self, idx: I, var elem: ElementType)` Insert an element `elem` into the list at index `idx`. Notes: Time Complexity: O(n) in len(self). **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of index to use. **Args:** * ​idx (`I`): The index to insert `elem` at `-len(self) <= idx <= len(self)`. * ​elem (`ElementType`): The item to insert into the list. **Raises:** When given an out of bounds index. ### `extend` `extend(mut self, var other: Self)` Extends the list with another. Notes: Time Complexity: O(1). **Args:** * ​other (`Self`): The list to append to this one. 
### `count` `count[_ElementType: Equatable & Copyable, //](self: LinkedList[_ElementType], elem: _ElementType) -> UInt` Count the occurrences of `elem` in the list. Notes: Time Complexity: O(n) in len(self) compares. **Parameters:** * ​\_ElementType ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The list element type, used to conditionally enable the function. **Args:** * ​elem (`_ElementType`): The element to search for. **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The number of occurrences of `elem` in the list. ### `__len__` `__len__(self) -> Int` Get the number of elements in the list. Notes: Time Complexity: O(1). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the list. ### `__iter__` `__iter__(ref self) -> _LinkedListIter[ElementType, self_is_origin]` Iterate over elements of the list, returning immutable references. Notes: Time Complexity: * O(1) for iterator construction. * O(n) in len(self) for a complete iteration of the list. **Returns:** `_LinkedListIter`: An iterator of immutable references to the list elements. ### `__reversed__` `__reversed__(self) -> _LinkedListIter[ElementType, self, False]` Iterate backwards over the list, returning immutable references. Notes: Time Complexity: * O(1) for iterator construction. * O(n) in len(self) for a complete iteration of the list. **Returns:** `_LinkedListIter`: A reversed iterator of immutable references to the list elements. ### `__str__` `__str__[_ElementType: Copyable & Writable](self: LinkedList[_ElementType]) -> String` Convert the list to its string representation. Notes: Time Complexity: O(n) in len(self). **Parameters:** * ​\_ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): Used to conditionally enable this function when `_ElementType` is `Writable`. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String representation of the list. ### `__repr__` `__repr__[_ElementType: Copyable & Writable](self: LinkedList[_ElementType]) -> String` Convert the list to its string representation. Notes: Time Complexity: O(n) in len(self). **Parameters:** * ​\_ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): Used to conditionally enable this function when `_ElementType` is `Writable`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String representation of the list. ### `write_to` `write_to[_ElementType: Copyable & Writable](self: LinkedList[_ElementType], mut writer: T)` Write the list to the given writer. Notes: Time Complexity: O(n) in len(self). **Parameters:** * ​\_ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): Used to conditionally enable this function when `_ElementType` is `Writable`. **Args:** * ​writer (`T`): The writer to write the list to.
--- ## Node
`struct Node[ElementType: Copyable]` A node in a linked list data structure. ## Parameters * ​ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of element stored in the node. ## Fields * ​value (`ElementType`): The value stored in this node. * ​prev (`UnsafePointer[Node[ElementType], origin_of()]`): The previous node in the list. * ​next (`UnsafePointer[Node[ElementType], origin_of()]`): The next node in the list. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = ElementType.__copyinit__is_trivial` ### `__del__is_trivial` `comptime __del__is_trivial = ElementType.__del__is_trivial` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = ElementType.__moveinit__is_trivial` ## Methods ### `__init__` `__init__(out self, var value: ElementType, prev: Optional[UnsafePointer[Node[ElementType], origin_of()]], next: Optional[UnsafePointer[Node[ElementType], origin_of()]])` Initialize a new Node with the given value and optional prev/next pointers. **Args:** * ​value (`ElementType`): The value to store in this node. * ​prev ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Optional pointer to the previous node. * ​next ([`Optional`](/mojo/stdlib/collections/optional/Optional)): Optional pointer to the next node. 
### `__str__` `__str__[_ElementType: Copyable & Writable](self: Node[_ElementType]) -> String` Convert this node's value to a string representation. **Parameters:** * ​\_ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): Used to conditionally enable this function if `_ElementType` is `Writable`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String representation of the node's value. ### `write_to` `write_to[_ElementType: Copyable & Writable](self: Node[_ElementType], mut writer: T)` Write this node's value to the given writer. **Parameters:** * ​\_ElementType ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): Used to conditionally enable this function if `_ElementType` is `Writable`. **Args:** * ​writer (`T`): The writer to write the value to.
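A minimal sketch of constructing a standalone `Node` with the initializer above (the import path is assumed; `None` for `prev` and `next` relies on the documented `Optional` parameters):

```mojo
from collections.linked_list import Node

fn main():
    # A node with a value and no neighbors.
    var node = Node[Int](42, None, None)
    print(node.value)   # 42
```

In normal use, nodes are created and linked internally by `LinkedList`; building them by hand is rarely needed.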
--- ## linked_list
## Structs * [​`LinkedList`](/mojo/stdlib/collections/linked_list/LinkedList): A doubly-linked list implementation. * [​`Node`](/mojo/stdlib/collections/linked_list/Node): A node in a linked list data structure.
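The module's `LinkedList` in action, sketching forward and backward iteration via `__iter__` and `__reversed__` (the import path is assumed; the list is built with `insert`):

```mojo
from collections import LinkedList

fn main() raises:
    var ll = LinkedList[Int]()
    for i in range(3):
        ll.insert(i, i * 10)    # builds [0, 10, 20]

    for x in ll:                # forward: 0, 10, 20
        print(x)
    for x in reversed(ll):      # backward: 20, 10, 0
        print(x)
```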
--- ## List
`struct List[T: Copyable]` A dynamically-allocated and resizable list. This is Mojo's primary dynamic array implementation, meaning the list can grow and shrink in size at runtime. However, all elements in a `List` must be the same type `T`, determined at compile time. You can create a `List` in several ways: ```mojo # Empty list var empty_list = List[Int]() # With pre-allocated capacity var preallocated = List[String](capacity=100) # With initial size and fill value var filled = List[Float64](length=10, fill=0.0) # With initial values and inferred type (Int) var numbers = [1, 2, 3, 4, 5] ``` Be aware of the following characteristics: * **Type safety**: All elements must be the same type `T`, determined at compile time. This is more restrictive than Python's lists but it improves performance: ```mojo var int_list = [1, 2, 3] # List[Int] var str_list = ["a", "b", "c"] # List[String] var mixed = [1, "hello"] # Error! All elements must be same type ``` However, you can get around this by defining your list type as [`Variant`](/mojo/stdlib/utils/variant/Variant). This is a discriminated union type, meaning it can store any number of different types that can vary at runtime. * **Value semantics:** A `List` is value semantic by default, so assignment creates a deep copy of all elements: ```mojo var list1 = [1, 2, 3] var list2 = list1 # Deep copy list2.append(4) print(list1.__str__()) # => [1, 2, 3] print(list2.__str__()) # => [1, 2, 3, 4] ``` This is different from Python, where assignment creates a reference to the same list. For more information, read about [value semantics](/mojo/manual/values/value-semantics). 
* **Iteration uses immutable references**: When iterating a list, you get immutable references to the actual elements, unless you specify `ref`: ```mojo var numbers = [10, 20, 30] # Default behavior creates immutable (read-only) references for num in numbers: num += 1 # error: expression must be mutable # Using `ref` gets mutable (read-write) references for ref num in numbers: num += 1 # Modifies the original elements print(numbers.__str__()) # => [11, 21, 31] ``` * **Out of bounds access**: Accessing elements with invalid indices will cause undefined behavior: ```mojo var my_list = [1, 2, 3] print(my_list[5]) # Undefined behavior (out of bounds) ``` For safe access, you should manually check bounds or use methods that handle errors gracefully: ```mojo var my_list = [1, 2, 3] if 5 < len(my_list): print(my_list[5]) # Safe: check bounds first else: print("Index out of bounds") # Some methods like index() raise exceptions try: var idx = my_list.index(99) # Raises ValueError if not found print("Found at index:", idx) except: print("Value not found in list") ``` Examples: ```mojo var my_list = [10, 20, 30] # Add elements my_list.append(40) # [10, 20, 30, 40] my_list.insert(1, 15) # [10, 15, 20, 30, 40] my_list.extend([50, 60]) # [10, 15, 20, 30, 40, 50, 60] # Access elements print(my_list[0]) # 10 (first element) print(my_list[-1]) # 60 (last element) my_list[1] = 25 # Modify element: [10, 25, 20, 30, 40, 50, 60] # Remove elements print(my_list.pop()) # Removes and returns last element (60) print(my_list.pop(2)) # Removes element at index 2 (20) # List properties print('len:', len(my_list)) # Current number of elements print('cap:', my_list.capacity) # Current allocated capacity # Multiply a list var repeated = [1, 2] * 3 print(repeated.__str__()) # [1, 2, 1, 2, 1, 2] # Iterate over a list: var fruits = ["apple", "banana", "orange"] # Iterate by value (immutable references) for fruit in fruits: print(fruit) # Iterate backwards by value for fruit in reversed(fruits): 
print(fruit) # Iterate by index for i in range(len(fruits)): print(i, fruits[i]) # Concatenate with + and += fruits += ["mango"] var more_fruits = fruits + ["grape", "kiwi"] print(more_fruits.__str__()) ``` ## Parameters * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of elements stored in the list. ## Fields * ​capacity (`Int`): The amount of elements that can fit in the list without resizing it. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _ListIter[T, iterable_origin]` The iterator type for this list. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self)` Constructs an empty list. `__init__(out self, *, capacity: Int)` Constructs a list with the given capacity. **Args:** * ​capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): The requested capacity of the list. 
`__init__(out self, *, length: Int, fill: T)` Constructs a list of the given length, with every element set to `fill`. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The requested length of the list. * ​fill (`T`): The value used to fill each element of the list. `__init__(out self, var *values: T, *, __list_literal__: Tuple[])` Constructs a list from the given values. **Args:** * ​\*values (`T`): The values to populate the list with. * ​\_\_list\_literal\_\_ ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tells Mojo to use this method for list literals. `__init__(out self, *, var elements: VariadicListMem[T, origin, is_owned])` Constructs a list from the given values. **Args:** * ​elements ([`VariadicListMem`](/mojo/stdlib/builtin/variadics/VariadicListMem)): The values to populate the list with. `__init__(out self, span: Span[T, origin])` Constructs a list from a `Span` of values. **Args:** * ​span ([`Span`](/mojo/stdlib/memory/span/Span)): The span of values to populate the list with. `__init__[IterableType: Iterable](out self, iterable: IterableType)` Constructs a list from an iterable of values. **Parameters:** * ​IterableType ([`Iterable`](/mojo/stdlib/iter/Iterable)): The type of the `iterable` argument. **Args:** * ​iterable (`IterableType`): The iterable of values to populate the list with. `__init__(out self, *, unsafe_uninit_length: Int)` Construct a list with the specified length, with uninitialized memory. This is unsafe, as it relies on the caller initializing the elements with unsafe operations, not assigning over the uninitialized data. **Args:** * ​unsafe\_uninit\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements to allocate. ### `__copyinit__` `__copyinit__(out self, existing: Self)` Creates a deep copy of the given list. **Args:** * ​existing (`Self`): The list to copy. ### `__del__` `__del__(deinit self)` Destroy all elements in the list and free its memory. ### `__bool__` `__bool__(self) -> Bool` Checks whether the list has any elements or not. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `False` if the list is empty, `True` if there is at least one element. ### `__getitem__` `__getitem__(self, slice: StridedSlice) -> Self` Gets the sequence of elements at the specified positions. **Args:** * ​slice ([`StridedSlice`](/mojo/stdlib/builtin/builtin_slice/StridedSlice)): A slice that specifies positions of the new list. **Returns:** `Self`: A new list containing the list at the specified slice. `__getitem__[origin: Origin[mut], //](ref [origin] self, slice: ContiguousSlice) -> Span[T, origin]` Gets the sequence of elements at the specified positions. **Parameters:** * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of `List`. **Args:** * ​slice ([`ContiguousSlice`](/mojo/stdlib/builtin/builtin_slice/ContiguousSlice)): A slice the specifies the positions of the new list. **Returns:** [`Span`](/mojo/stdlib/memory/span/Span): A span over the specified slice. `__getitem__[I: Indexer, //](ref self, idx: I) -> ref [self] T` Gets the list element at the given index. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): A type that can be used as an index. **Args:** * ​idx (`I`): The index of the element. **Returns:** `ref`: A reference to the element at the given index. ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if two lists are equal. Examples: ```mojo var x = [1, 2, 3] var y = [1, 2, 3] print("x and y are equal" if x == y else "x and y are not equal") ``` **Constraints:** `T` must conform to `Equatable`. **Args:** * ​other (`Self`): The list to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the lists are equal, False otherwise. ### `__ne__` `__ne__[U: Equatable & Copyable, //](self: List[U], other: List[U]) -> Bool` Checks if two lists are not equal. 
Examples: ```mojo var x = [1, 2, 3] var y = [1, 2, 4] print("x and y are not equal" if x != y else "x and y are equal") ``` **Parameters:** * ​U ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the list. Must implement the trait `Equatable`. **Args:** * ​other ([`List`](/mojo/stdlib/collections/list/List)): The list to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the lists are not equal, False otherwise. ### `__contains__` `__contains__[U: Equatable & Copyable, //](self: List[U], value: U) -> Bool` Verify if a given value is present in the list. Examples: ```mojo var x = [1, 2, 3] print("x contains 3" if 3 in x else "x does not contain 3") ``` **Parameters:** * ​U ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the list. Must implement the trait `Equatable`. **Args:** * ​value (`U`): The value to find. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the value is contained in the list, False otherwise. ### `__add__` `__add__(self, var other: Self) -> Self` Concatenates self with other and returns the result as a new list. **Args:** * ​other (`Self`): List whose elements will be combined with the elements of self. **Returns:** `Self`: The newly created list. ### `__mul__` `__mul__(self, x: Int) -> Self` Multiplies the list by x and returns a new list. **Args:** * ​x ([`Int`](/mojo/stdlib/builtin/int/Int)): The multiplier number. **Returns:** `Self`: The new list. ### `__iadd__` `__iadd__(mut self, var other: Self)` Appends the elements of other into self. **Args:** * ​other (`Self`): List whose elements will be appended to self. ### `__imul__` `__imul__(mut self, x: Int)` Appends the original elements of this list x-1 times or clears it if x is <= 0. 
```mojo var a = [1, 2] a *= 2 # a = [1, 2, 1, 2] ``` **Args:** * ​x ([`Int`](/mojo/stdlib/builtin/int/Int)): The multiplier number. ### `__iter__` `__iter__(ref self) -> _ListIter[T, self_is_origin]` Iterate over elements of the list, returning immutable references. **Returns:** `_ListIter`: An iterator of immutable references to the list elements. ### `__reversed__` `__reversed__(ref self) -> _ListIter[T, self_is_origin, False]` Iterate backwards over the list, returning immutable references. **Returns:** `_ListIter`: A reversed iterator of immutable references to the list elements. ### `__len__` `__len__(self) -> Int` Gets the number of elements in the list. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the list. ### `__str__` `__str__(self) -> String` Returns a string representation of a `List`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the list. ### `write_to` `write_to(self, mut writer: T)` Write `my_list.__str__()` to a `Writer`. **Constraints:** `T` must conform to `Representable`. **Args:** * ​writer (`T`): The object to write to. ### `__repr__` `__repr__(self) -> String` Returns a string representation of a `List`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the list. ### `byte_length` `byte_length(self) -> Int` Gets the byte length of the List (`len(self) * size_of[T]()`). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The byte length of the List (`len(self) * size_of[T]()`). ### `append` `append(mut self, var value: T)` Appends a value to this list. Notes: If there is no capacity left, resizes to twice the current capacity. Except for 0 capacity where it sets 1. **Args:** * ​value (`T`): The value to append. ### `insert` `insert(mut self, i: Int, var value: T)` Inserts a value to the list at the given index. `a.insert(len(a), value)` is equivalent to `a.append(value)`. 
**Args:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index for the value. * ​value (`T`): The value to insert. ### `extend` `extend(mut self, var other: Self)` Extends this list by consuming the elements of `other`. **Args:** * ​other (`Self`): List whose elements will be added in order at the end of this list. `extend(mut self, elements: Span[T, origin])` Extend this list by copying elements from a `Span`. The resulting list will have the length `len(self) + len(elements)`. **Args:** * ​elements ([`Span`](/mojo/stdlib/memory/span/Span)): The elements to copy into this list. `extend[dtype: DType, //](mut self: List[Scalar[dtype]], value: SIMD[dtype, size])` Extends this list with the elements of a vector. Notes: If there is no capacity left, resizes to `len(self) + value.size`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType. **Args:** * ​value ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to append. `extend[dtype: DType, //](mut self: List[Scalar[dtype]], value: SIMD[dtype, size], *, count: Int)` Extends this list with `count` number of elements from a vector. Notes: If there is no capacity left, resizes to `len(self) + count`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType. **Args:** * ​value ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The value to append. * ​count ([`Int`](/mojo/stdlib/builtin/int/Int)): The amount of items to append. Must be less than or equal to `value.size`. `extend[dtype: DType, //](mut self: List[Scalar[dtype]], value: Span[Scalar[dtype], origin])` Extends this list with the elements of a `Span`. Notes: If there is no capacity left, resizes to `len(self) + len(value)`. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The DType. **Args:** * ​value ([`Span`](/mojo/stdlib/memory/span/Span)): The value to append. ### `pop` `pop(mut self, i: Int = -1) -> T` Pops a value from the list at the given index. 
**Args:** * ​i ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the value to pop. **Returns:** `T`: The popped value. ### `reserve` `reserve(mut self, new_capacity: Int)` Reserves the requested capacity. Notes: If the current capacity is greater than or equal, this is a no-op. Otherwise, the storage is reallocated and the data is moved. **Args:** * ​new\_capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): The new capacity. ### `resize` `resize(mut self, new_size: Int, value: T)` Resizes the list to the given new size. Notes: If the new size is smaller than the current one, elements at the end are discarded. If the new size is larger than the current one, the list is extended with copies of `value` up to the requested size. **Args:** * ​new\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The new size. * ​value (`T`): The value to use to populate new elements. `resize(mut self, *, unsafe_uninit_length: Int)` Resizes the list to the given new size leaving any new elements uninitialized. If the new size is smaller than the current one, elements at the end are discarded. If the new size is larger than the current one, the list is extended and the new elements are left uninitialized. **Args:** * ​unsafe\_uninit\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The new size. ### `shrink` `shrink(mut self, new_size: Int)` Resizes to the given new size, which must be <= the current size. Notes: With no new value provided, the new size must be smaller than or equal to the current one. Elements at the end are discarded. **Args:** * ​new\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The new size. ### `reverse` `reverse(mut self)` Reverses the elements of the list. ### `index` `index[C: Equatable & Copyable, //](ref self: List[C], value: C, start: Int = 0, stop: Optional[Int] = None) -> Int` Returns the index of the first occurrence of a value in the list, restricted to the range given by the start and stop bounds. 
Examples: ```mojo var my_list = [1, 2, 3] print(my_list.index(2)) # prints `1` ``` **Parameters:** * ​C ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the list. Must implement the `Equatable` trait. **Args:** * ​value (`C`): The value to search for. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The starting index of the search, treated as a slice index (defaults to 0). * ​stop ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The ending index of the search, treated as a slice index (defaults to None, which means the end of the list). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The index of the first occurrence of the value in the list. **Raises:** ValueError: If the value is not found in the list. ### `clear` `clear(mut self)` Clears the elements in the list. ### `steal_data` `steal_data(mut self) -> UnsafePointer[T, origin_of()]` Take ownership of the underlying pointer from the list. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The underlying data. ### `unsafe_get` `unsafe_get(ref self, idx: Int) -> ref [self] T` Get a reference to an element of self without checking index bounds. Notes: Users should consider using `__getitem__` instead of this method as it is unsafe. If an index is out of bounds, this method will not abort, it will be considered undefined behavior. Note that there is no wraparound for negative indices, caution is advised. Using negative indices is considered undefined behavior. Never use `my_list.unsafe_get(-1)` to get the last element of the list. Instead, do `my_list.unsafe_get(len(my_list) - 1)`. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to get. **Returns:** `ref`: A reference to the element at the given index. ### `unsafe_set` `unsafe_set(mut self, idx: Int, var value: T)` Write a value to a given location without checking index bounds. 
Notes: Users should consider using `my_list[idx] = value` instead of this method as it is unsafe. If an index is out of bounds, this method will not abort, it will be considered undefined behavior. Note that there is no wraparound for negative indices, caution is advised. Using negative indices is considered undefined behavior. Never use `my_list.unsafe_set(-1, value)` to set the last element of the list. Instead, do `my_list.unsafe_set(len(my_list) - 1, value)`. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to set. * ​value (`T`): The value to set. ### `count` `count[_T: Equatable & Copyable, //](self: List[_T], value: _T) -> Int` Counts the number of occurrences of a value in the list. **Parameters:** * ​\_T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the list. Must implement the trait `Equatable`. **Args:** * ​value (`_T`): The value to count. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of occurrences of the value in the list. ### `swap_elements` `swap_elements(mut self, elt_idx_1: Int, elt_idx_2: Int)` Swaps elements at the specified indexes if they are different. Examples: ```mojo var my_list = [1, 2, 3] my_list.swap_elements(0, 2) print(my_list.__str__()) # 3, 2, 1 ``` Notes: This is useful because `swap(my_list[i], my_list[j])` cannot be supported by Mojo, because a mutable alias may be formed. **Args:** * ​elt\_idx\_1 ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of one element. * ​elt\_idx\_2 ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the other element. ### `unsafe_ptr` `unsafe_ptr[origin: Origin[mut], address_space: AddressSpace, //](ref [origin, $2] self) -> UnsafePointer[T, origin, address_space=address_space]` Retrieves a pointer to the underlying memory. **Parameters:** * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the `List`. 
* ​address\_space ([`AddressSpace`](/mojo/stdlib/memory/pointer/AddressSpace)): The `AddressSpace` of the `List`. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The pointer to the underlying memory.
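A few of the `List` methods above in one sketch. Slicing behavior is assumed to be reachable through the usual `[start:stop:step]` syntax (a strided slice copies into a new `List`, a contiguous slice returns a `Span` view), and `List[UInt8]` stands in for `List[Scalar[DType.uint8]]` in the SIMD `extend` overloads:

```mojo
fn main():
    var nums = [1, 2, 3, 4, 5]

    # Slicing.
    var odds = nums[::2]        # strided slice -> new List: [1, 3, 5]
    var mid = nums[1:4]         # contiguous slice -> Span over [2, 3, 4]
    print(len(odds), len(mid))  # 3 3

    # Capacity and size management.
    nums.reserve(16)            # capacity >= 16; length unchanged
    nums.shrink(3)              # discards the tail: [1, 2, 3]
    print(len(nums))            # 3

    # SIMD-based extend.
    var bytes = List[UInt8]()
    bytes.extend(SIMD[DType.uint8, 4](1, 2, 3, 4))           # all 4 lanes
    bytes.extend(SIMD[DType.uint8, 4](9, 9, 9, 9), count=2)  # first 2 lanes
    print(len(bytes))           # 6
```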
--- ## list (List)
Defines the List type. These APIs are imported automatically, just like builtins. ## Structs * [​`List`](/mojo/stdlib/collections/list/List): A dynamically-allocated and resizable list.
--- ## Optional
`struct Optional[T: Movable]` A type modeling a value which may or may not be present. Optional values can be thought of as a type-safe nullable pattern. It either holds a value or is `None`, and you must check for and explicitly extract the value to use it. Currently `T` is required to be `Copyable` so copy/move can be implemented for `Optional` and it can be used in collections. Examples: ```mojo var a = Optional(1) var b = Optional[Int](None) if a: print(a.value()) # prints 1 if b: # Bool(b) is False, so no print print(b.value()) var c = a.or_else(2) var d = b.or_else(2) print(c) # prints 1 print(d) # prints 2 ``` ## Parameters * ​T ([`Movable`](/mojo/stdlib/builtin/value/Movable)): The type of value stored in the `Optional`. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Iterator`](/mojo/stdlib/iter/Iterator), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `Element` `comptime Element = T` The element type of this optional. ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = Optional[T]` The iterator type for this optional. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. 
* ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self)` Construct an empty `Optional`. `@implicit` `__init__(out self, var value: T)` Construct an `Optional` containing a value. **Args:** * ​value (`T`): The value to store in the `Optional`. `@implicit` `__init__(out self, value: NoneType)` Construct an empty `Optional`. **Args:** * ​value ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): Must be exactly `None`. ### `__bool__` `__bool__(self) -> Bool` Return true if the Optional has a value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `Optional` has a value and False otherwise. ### `__getitem__` `__getitem__(ref self) -> ref [origin_of($1._value)] T` Retrieve a reference to the value inside the `Optional`. **Returns:** `ref`: A reference to the value inside the `Optional`. **Raises:** On empty `Optional`. ### `__invert__` `__invert__(self) -> Bool` Return False if the `Optional` has a value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): False if the `Optional` has a value and True otherwise. ### `__eq__` `__eq__(self, rhs: NoneType) -> Bool` Return `True` if a value is not present. **Args:** * ​rhs ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): The `None` value to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `True` if a value is not present, `False` otherwise. `__eq__[_T: Equatable & Copyable](self: Optional[_T], rhs: Optional[_T]) -> Bool` Return `True` if this is the same as another `Optional` value, meaning both are absent, or both are present and have the same underlying value. **Parameters:** * ​\_T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the list. Must implement the traits `Copyable` and `Equatable`. **Args:** * ​rhs ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The value to compare to. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the values are the same. ### `__ne__` `__ne__(self, rhs: NoneType) -> Bool` Return `True` if a value is present. **Args:** * ​rhs ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): The `None` value to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): `False` if a value is not present, `True` otherwise. `__ne__[_T: Equatable & Copyable, //](self: Optional[_T], rhs: Optional[_T]) -> Bool` Return `False` if this is the same as another `Optional` value, meaning both are absent, or both are present and have the same underlying value. **Parameters:** * ​\_T ([`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the elements in the list. Must implement the traits `Copyable` and `Equatable`. **Args:** * ​rhs ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The value to compare to. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): False if the values are the same. ### `__is__` `__is__(self, other: NoneType) -> Bool` Return `True` if the Optional has no value. Notes: It allows you to use the following syntax: `if my_optional is None:`. **Args:** * ​other ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): The value to compare to (None). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Optional has no value and False otherwise. ### `__isnot__` `__isnot__(self, other: NoneType) -> Bool` Return `True` if the Optional has a value. Notes: It allows you to use the following syntax: `if my_optional is not None:`. **Args:** * ​other ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): The value to compare to (None). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Optional has a value and False otherwise. ### `__iter__` `__iter__(ref self) -> Self` Iterate over the Optional's possibly contained value. Optionals act as a collection of size 0 or 1. 
**Returns:** `Self`: An iterator over the Optional's value (if present). ### `__has_next__` `__has_next__(self) -> Bool` Return true if the Optional has a value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Optional contains a value, False otherwise. ### `__next__` `__next__(mut self) -> Optional[T].Element` Return the contained value of the Optional. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): The value contained in the Optional. ### `bounds` `bounds(self) -> Tuple[Int, Optional[Int]]` Return the bounds of the `Optional`, which is 0 or 1. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): A tuple containing the length (0 or 1) and an `Optional` containing the length. ### `__str__` `__str__(self) -> String` Return the string representation of the value of the `Optional`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the `Optional`. ### `__repr__` `__repr__(self) -> String` Returns the verbose string representation of the `Optional`. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A verbose string representation of the `Optional`. ### `__merge_with__` `__merge_with__[other_type: AnyStruct[Bool]](self) -> Bool` Merge with other bools in an expression. **Parameters:** * ​other\_type (`AnyStruct`): The type of the bool to merge with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): A Bool after merging with the specified `other_type`. ### `write_to` `write_to(self, mut writer: T)` Write `Optional` string representation to a `Writer`. **Args:** * ​writer (`T`): The object to write to. ### `value` `value(ref self) -> ref [origin_of($1._value)] T` Retrieve a reference to the value of the `Optional`. Notes: This will abort on empty `Optional`. **Returns:** `ref`: A reference to the contained data of the `Optional` as a reference. 
### `unsafe_value` `unsafe_value(ref self) -> ref [origin_of($1._value)] T` Unsafely retrieve a reference to the value of the `Optional`. Notes: This will **not** abort on empty `Optional`. **Returns:** `ref`: A reference to the contained data of the `Optional` as a reference. ### `take` `take(mut self) -> T` Move the value out of the `Optional`. Notes: This will abort on empty `Optional`. **Returns:** `T`: The contained data of the `Optional` as an owned T value. ### `unsafe_take` `unsafe_take(mut self) -> T` Unsafely move the value out of the `Optional`. Notes: This will **not** abort on empty `Optional`. **Returns:** `T`: The contained data of the `Optional` as an owned T value. ### `or_else` `or_else(deinit self, var default: T) -> T` Return the underlying value contained in the `Optional` or a default value if the `Optional`'s underlying value is not present. **Args:** * ​default (`T`): The new value to use if no value was present. **Returns:** `T`: The underlying value contained in the `Optional` or a default value. ### `copied` `copied[mut: Bool, origin: Origin[mut], //, _T: Copyable](self: Optional[Pointer[_T, origin]]) -> Optional[_T]` Converts an `Optional` containing a Pointer to an `Optional` of an owned value by copying. Examples: Copy the value of an `Optional[Pointer[_]]` ```mojo var data = "foo" var opt = Optional(Pointer(to=data)) var opt_owned: Optional[String] = opt.copied() ``` Notes: If `self` is an empty `Optional`, the returned `Optional` will be empty as well. **Parameters:** * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Mutability of the pointee origin. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): Origin of the contained `Pointer`. * ​\_T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable)): Type of the owned result value. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): An `Optional` containing an owned copy of the pointee value.
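A short sketch combining the accessors above; the values and the structure of the example are illustrative, and it assumes `value()`, `take()`, and `or_else()` behave as documented:

```mojo
def main():
    var opt = Optional(String("mojo"))

    # `value()` aborts on an empty Optional, so guard with a truthiness check.
    if opt:
        print(opt.value())  # mojo

    # `or_else()` consumes the Optional and supplies a fallback when empty.
    var empty = Optional[Int](None)
    print(empty.or_else(7))  # 7

    # `take()` moves the value out, leaving the Optional empty.
    var again = Optional(3)
    var x = again.take()
    print(x)  # 3
```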
--- ## OptionalReg
`@register_passable(trivial)` `struct OptionalReg[T: AnyTrivialRegType]` A register-passable optional type. This struct optionally contains a value. It only works with trivial register passable types at the moment. ## Parameters * ​T ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of value stored in the Optional. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = OptionalReg[T]` The device-side type for this optional register. ## Methods ### `__init__` `__init__() -> Self` Create an optional with a value of None. `@implicit` `__init__(value: T) -> Self` Create an optional with a value. **Args:** * ​value (`T`): The value. `@implicit` `__init__(value: NoneType) -> Self` Create an optional without a value from a None literal. **Args:** * ​value ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): The None value. ### `__bool__` `__bool__(self) -> Bool` Return true if the optional has a value. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the optional has a value and False otherwise. ### `__is__` `__is__(self, other: NoneType) -> Bool` Return `True` if the Optional has no value. 
It allows you to use the following syntax: `if my_optional is None:` **Args:** * ​other ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): The value to compare to (None). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Optional has no value and False otherwise. ### `__isnot__` `__isnot__(self, other: NoneType) -> Bool` Return `True` if the Optional has a value. It allows you to use the following syntax: `if my_optional is not None:` **Args:** * ​other ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): The value to compare to (None). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Optional has a value and False otherwise. ### `get_type_name` `static get_type_name() -> String` Get the human-readable type name for this `OptionalReg` type. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the type, e.g. `OptionalReg[Int]`. ### `get_device_type_name` `static get_device_type_name() -> String` Get the human-readable device type name for this `OptionalReg` type. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation of the device type (same as type name for `OptionalReg`). ### `__merge_with__` `__merge_with__[other_type: AnyStruct[Bool]](self) -> Bool` Merge with other bools in an expression. **Parameters:** * ​other\_type (`AnyStruct`): The type of the bool to merge with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): A Bool after merging with the specified `other_type`. ### `value` `value(self) -> T` Get the optional value. **Returns:** `T`: The contained value. ### `or_else` `or_else(var self, var default: T) -> T` Return the underlying value contained in the Optional or a default value if the Optional's underlying value is not present. **Args:** * ​default (`T`): The new value to use if no value was present. **Returns:** `T`: The underlying value contained in the Optional or a default value.
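Because `OptionalReg` is register-passable and restricted to trivial types, it is a natural fit for optional scalar parameters. A minimal sketch of the methods above; the helper name `clamp_to` is illustrative, and the sketch assumes the documented implicit conversion from a `None` literal:

```mojo
fn clamp_to(limit: OptionalReg[Int], x: Int) -> Int:
    # `is None` uses `__is__`; `value()` is only called when a value is present.
    if limit is None:
        return x
    return min(x, limit.value())

def main():
    print(clamp_to(OptionalReg[Int](10), 42))  # 10
    print(clamp_to(None, 42))                  # 42
```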
--- ## optional (Optional)
Defines Optional, a type modeling a value which may or may not be present. Optional values can be thought of as a type-safe nullable pattern. Your value can take on a value or `None`, and you need to check and explicitly extract the value to get it out. Examples: ```mojo var a = Optional(1) var b = Optional[Int](None) if a: print(a.value()) # prints 1 if b: # Bool(b) is False, so no print print(b.value()) var c = a.or_else(2) var d = b.or_else(2) print(c) # prints 1 print(d) # prints 2 ``` ## Structs * [​`Optional`](/mojo/stdlib/collections/optional/Optional): A type modeling a value which may or may not be present. * [​`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg): A register-passable optional type.
--- ## Set
`struct Set[T: KeyElement, H: Hasher = default_hasher]` A set data type. O(1) average-case amortized add, remove, and membership check. ```mojo from collections import Set var set = { 1, 2, 3 } print(len(set)) # 3 set.add(4) for element in set: print(element) set -= Set[Int](3, 4, 5) print(set == Set[Int](1, 2)) # True print(set | Set[Int](0, 1) == Set[Int](0, 1, 2)) # True var element = set.pop() print(len(set)) # 1 ``` ## Parameters * ​T ([`KeyElement`](/mojo/stdlib/collections/dict/#keyelement)): The element type of the set. Must implement KeyElement. * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The type of the hasher used to hash keys. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = _DictKeyIter[T, NoneType, H, iterable_origin]` The iterator type for this set. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__init__` `__init__(out self, *ts: T, *, __set_literal__: Tuple[] = Tuple[]())` Construct a set from initial elements. 
**Args:** * ​\*ts (`T`): Variadic of elements to add to the set. * ​\_\_set\_literal\_\_ ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tell Mojo to use this method for set literals. `__init__(out self, elements: List[T])` Construct a set from a List of elements. **Args:** * ​elements ([`List`](/mojo/stdlib/collections/list/List)): A list of elements to add to the set. ### `__bool__` `__bool__(self) -> Bool` Whether the set is non-empty or not. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the set is non-empty, False if it is empty. ### `__lt__` `__lt__(self, other: Self) -> Bool` Overloads the < operator for strict subset comparison of sets. **Args:** * ​other (`Self`): The set to compare against for the strict subset relationship. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the set is a strict subset of the `other` set, False otherwise. ### `__le__` `__le__(self, other: Self) -> Bool` Overloads the <= operator for sets. Works like the `issubset()` method. **Args:** * ​other (`Self`): Another Set instance to check against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this set is a subset of the `other` set, False otherwise. ### `__eq__` `__eq__(self, other: Self) -> Bool` Set equality. **Args:** * ​other (`Self`): Another Set instance to check equality against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the sets contain the same elements and False otherwise. ### `__gt__` `__gt__(self, other: Self) -> Bool` Overloads the > operator for strict superset comparison of sets. **Args:** * ​other (`Self`): The set to compare against for the strict superset relationship. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the set is a strict superset of the `other` set, False otherwise. ### `__ge__` `__ge__(self, other: Self) -> Bool` Overloads the >= operator for sets. Works like the `issuperset()` method. **Args:** * ​other (`Self`): Another Set instance to check against. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this set is a superset of the `other` set, False otherwise. ### `__contains__` `__contains__(self, t: T) -> Bool` Whether or not the set contains an element. **Args:** * ​t (`T`): The element to check membership in the set. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): Whether or not the set contains the element. ### `__sub__` `__sub__(self, other: Self) -> Self` Set subtraction. **Args:** * ​other (`Self`): Another Set instance to subtract from this one. **Returns:** `Self`: A new set containing elements of this set, but not containing any elements which were in the `other` set. ### `__and__` `__and__(self, other: Self) -> Self` The set intersection operator. **Args:** * ​other (`Self`): Another Set instance to intersect with this one. **Returns:** `Self`: A new set containing only the elements which appear in both this set and the `other` set. ### `__or__` `__or__(self, other: Self) -> Self` The set union operator. **Args:** * ​other (`Self`): Another Set instance to union with this one. **Returns:** `Self`: A new set containing any elements which appear in either this set or the `other` set. ### `__xor__` `__xor__(self, other: Self) -> Self` Overloads the ^ operator for sets. Works like the `symmetric_difference()` method. **Args:** * ​other (`Self`): The set to find the symmetric difference with. **Returns:** `Self`: A new set containing the symmetric difference of the two sets. ### `__isub__` `__isub__(mut self, other: Self)` In-place set subtraction. Updates the set to remove any elements from the `other` set. **Args:** * ​other (`Self`): Another Set instance to subtract from this one. ### `__iand__` `__iand__(mut self, other: Self)` In-place set intersection. Updates the set to contain only the elements which are already in the set and are also contained in the `other` set. **Args:** * ​other (`Self`): Another Set instance to intersect with this one. 
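The binary and in-place operators above can be sketched with the set-literal syntax shown earlier. The element counts in the comments follow from the small example values:

```mojo
def main():
    var a = {1, 2, 3}
    var b = {3, 4}
    print(len(a - b))  # 2, i.e. {1, 2}
    print(len(a & b))  # 1, i.e. {3}
    print(len(a | b))  # 4, i.e. {1, 2, 3, 4}
    print(len(a ^ b))  # 3, i.e. {1, 2, 4}
    a |= b             # in-place union
    print(len(a))      # 4
```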
### `__ixor__` `__ixor__(mut self, other: Self)` Overloads the ^= operator. Works like the `symmetric_difference_update()` method. Updates the set with the symmetric difference of itself and another set. **Args:** * ​other (`Self`): The set to find the symmetric difference with. ### `__ior__` `__ior__(mut self, other: Self)` In-place set union. Updates the set to contain all elements in the `other` set as well as keeping all elements it already contained. **Args:** * ​other (`Self`): Another Set instance to union with this one. ### `__len__` `__len__(self) -> Int` The size of the set. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the set. ### `__hash__` `__hash__[_H: Hasher](self, mut hasher: _H)` Updates hasher with the underlying values. The update is order independent, so s1 == s2 -> hash(s1) == hash(s2). **Parameters:** * ​\_H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`_H`): The hasher instance. ### `__str__` `__str__[U: Copyable & Hashable & Equatable & Representable, //](self: Set[U]) -> String` Returns the string representation of the set. **Parameters:** * ​U ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable) & [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Representable`](/mojo/stdlib/builtin/repr/Representable)): The type of the set elements. Must implement the `Representable` and `KeyElement` traits. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the set. ### `__repr__` `__repr__[U: Copyable & Hashable & Equatable & Representable, //](self: Set[U]) -> String` Returns the string representation of the set. 
**Parameters:** * ​U ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable) & [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Representable`](/mojo/stdlib/builtin/repr/Representable)): The type of the set elements. Must implement the `Representable` and `KeyElement` traits. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string representation of the set. ### `write_to` `write_to[U: Copyable & Hashable & Equatable & Representable, //](self: Set[U], mut writer: T)` Write Set string representation to a `Writer`. **Parameters:** * ​U ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable) & [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable) & [`Representable`](/mojo/stdlib/builtin/repr/Representable)): The type of the set elements. Must implement the `Representable` and `KeyElement` traits. **Args:** * ​writer (`T`): The object to write to. ### `__iter__` `__iter__(ref self) -> _DictKeyIter[T, NoneType, H, self_is_origin]` Iterate over elements of the set, returning immutable references. **Returns:** `_DictKeyIter`: An iterator of immutable references to the set elements. ### `add` `add(mut self, t: T)` Add an element to the set. **Args:** * ​t (`T`): The element to add to the set. ### `remove` `remove(mut self, t: T)` Remove an element from the set. **Args:** * ​t (`T`): The element to remove from the set. **Raises:** If the element isn't in the set to remove. ### `pop` `pop(mut self) -> T` Remove any one item from the set, and return it. As an implementation detail this will remove the first item according to insertion order. This is practically useful for breadth-first search implementations. **Returns:** `T`: The element which was removed from the set. **Raises:** If the set is empty. ### `union` `union(self, other: Self) -> Self` Set union. **Args:** * ​other (`Self`): Another Set instance to union with this one. 
**Returns:** `Self`: A new set containing any elements which appear in either this set or the `other` set. ### `intersection` `intersection(self, other: Self) -> Self` Set intersection. **Args:** * ​other (`Self`): Another Set instance to intersect with this one. **Returns:** `Self`: A new set containing only the elements which appear in both this set and the `other` set. ### `difference` `difference(self, other: Self) -> Self` Set difference. **Args:** * ​other (`Self`): Another Set instance to find the difference with this one. **Returns:** `Self`: A new set containing elements that are in this set but not in the `other` set. ### `update` `update(mut self, other: Self)` In-place set update. Updates the set to contain all elements in the `other` set as well as keeping all elements it already contained. **Args:** * ​other (`Self`): Another Set instance to union with this one. ### `intersection_update` `intersection_update(mut self, other: Self)` In-place set intersection update. Updates the set by retaining only elements found in both this set and the `other` set, removing all other elements. The result is the intersection of this set with `other`. **Args:** * ​other (`Self`): Another Set instance to intersect with this one. ### `difference_update` `difference_update(mut self, other: Self)` In-place set subtraction. Updates the set by removing all elements found in the `other` set, effectively keeping only elements that are unique to this set. **Args:** * ​other (`Self`): Another Set instance to subtract from this one. ### `issubset` `issubset(self, other: Self) -> Bool` Check if this set is a subset of another set. **Args:** * ​other (`Self`): Another Set instance to check against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this set is a subset of the `other` set, False otherwise. ### `isdisjoint` `isdisjoint(self, other: Self) -> Bool` Check if this set is disjoint with another set. 
**Args:** * ​other (`Self`): Another Set instance to check against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this set is disjoint with the `other` set, False otherwise. ### `issuperset` `issuperset(self, other: Self) -> Bool` Check if this set is a superset of another set. **Args:** * ​other (`Self`): Another Set instance to check against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this set is a superset of the `other` set, False otherwise. ### `symmetric_difference` `symmetric_difference(self, other: Self) -> Self` Returns the symmetric difference of two sets. **Args:** * ​other (`Self`): The set to find the symmetric difference with. **Returns:** `Self`: A new set containing the symmetric difference of the two sets. ### `symmetric_difference_update` `symmetric_difference_update(mut self, other: Self)` Updates the set with the symmetric difference of itself and another set. **Args:** * ​other (`Self`): The set to find the symmetric difference with. ### `discard` `discard(mut self, value: T)` Remove a value from the set if it exists. Does nothing otherwise. **Args:** * ​value (`T`): The element to remove from the set. ### `clear` `clear(mut self)` Removes all elements from the set. This method modifies the set in-place, removing all of its elements. After calling this method, the set will be empty.
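The query and mutation methods above can be sketched as follows; the example values are illustrative:

```mojo
def main():
    var small = {1, 2}
    var big = {1, 2, 3}
    var other = {4, 5}
    print(small.issubset(big))     # True
    print(big.issuperset(small))   # True
    print(small.isdisjoint(other)) # True
    big.discard(99)  # not present: a no-op, unlike `remove`, which would raise
    big.clear()
    print(len(big))  # 0
```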
--- ## set (Set)
Implements the Set datatype. ## Structs * [​`Set`](/mojo/stdlib/collections/set/Set): A set data type.
--- ## Codepoint
`struct Codepoint` A Unicode codepoint, typically a single user-recognizable character; restricted to valid Unicode scalar values. This type is restricted to store a single Unicode [*scalar value*][1], typically encoding a single user-recognizable character. All valid Unicode scalar values are in the range(s) 0 to 0xD7FF and 0xE000 to 0x10FFFF, inclusive. This type guarantees that the stored integer value falls in these ranges. [1]: https://www.unicode.org/glossary/#unicode_scalar_value **Codepoints versus Scalar Values** Formally, Unicode defines a codespace of values in the range 0 to 0x10FFFF inclusive, and a [Unicode codepoint](https://www.unicode.org/glossary/#code_point) is any integer falling within that range. However, due to historical reasons, it became necessary to "carve out" a subset of the codespace, excluding the codepoints in the range 0xD800–0xDFFF. The codepoints outside that excluded range are known as [Unicode scalar values][1]. The codepoints in the range 0xD800–0xDFFF are known as "surrogate" codepoints. The surrogate codepoints will never be assigned a semantic meaning, and can only validly appear in UTF-16 encoded text. The difference between codepoints and scalar values is a technical distinction related to the backwards-compatible workaround chosen to enable UTF-16 to encode the full range of the Unicode codespace. For simplicity's sake, and to avoid a confusing clash with the Mojo `Scalar` type, this type is pragmatically named `Codepoint`, even though it is restricted to valid scalar values. 
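The scalar-value restriction described above is observable through the `from_u32` constructor, which returns an empty `Optional` for invalid inputs. A minimal sketch (it assumes the truthiness of `Optional` behaves as documented in the `optional` module):

```mojo
def main():
    # Valid scalar values construct successfully.
    if Codepoint.from_u32(0x41):
        print("U+0041 is a valid scalar value")
    # Surrogates (0xD800-0xDFFF) are not scalar values, so None is returned.
    if not Codepoint.from_u32(0xD800):
        print("U+D800 (surrogate) is rejected")
    # Values past 0x10FFFF fall outside the Unicode codespace entirely.
    if not Codepoint.from_u32(0x110000):
        print("0x110000 is rejected")
```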
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Intable`](/mojo/stdlib/builtin/int/Intable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(out self, *, unsafe_unchecked_codepoint: UInt32)` Construct a `Codepoint` from a code point value without checking that it falls in the valid range. Safety: The provided codepoint value MUST be a valid Unicode scalar value. Providing a value outside of the valid range could lead to undefined behavior in algorithms that depend on the validity guarantees of this type. **Args:** * ​unsafe\_unchecked\_codepoint ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): A valid Unicode scalar value code point. `__init__(out self, codepoint: UInt8)` Construct a `Codepoint` from a single byte value. This constructor cannot fail because non-negative 8-bit integers are valid Unicode scalar values. **Args:** * ​codepoint ([`UInt8`](/mojo/stdlib/builtin/simd/#uint8)): The 8-bit codepoint value to convert to a `Codepoint`. ### `__lt__` `__lt__(self, other: Self) -> Bool` Return True if this character is less than a different codepoint value from `other`. **Args:** * ​other (`Self`): The codepoint value to compare against. 
**Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this character's value is less than the other codepoint value; False otherwise. ### `__eq__` `__eq__(self, other: Self) -> Bool` Return True if this character has the same codepoint value as `other`. **Args:** * ​other (`Self`): The codepoint value to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this character and `other` have the same codepoint value; False otherwise. ### `from_u32` `static from_u32(codepoint: UInt32) -> Optional[Codepoint]` Construct a `Codepoint` from a code point value. Returns None if the provided `codepoint` is not in the valid range. **Args:** * ​codepoint ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): An integer representing a Unicode scalar value. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): A `Codepoint` if `codepoint` falls in the valid range for Unicode scalar values, otherwise None. ### `ord` `static ord(string: StringSlice[origin]) -> Self` Returns the `Codepoint` that represents the given single-character string. Given a string containing one character, return a `Codepoint` representing the codepoint of that character. For example, `Codepoint.ord("a")` returns the codepoint `97`. This is the inverse of the `chr()` function. This function is similar to the `ord()` free function, except that it returns a `Codepoint` instead of an `Int`. **Args:** * ​string ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The input string, which must contain only a single character. **Returns:** `Self`: A `Codepoint` representing the codepoint of the given character. ### `unsafe_decode_utf8_codepoint` `static unsafe_decode_utf8_codepoint(s: Span[UInt8, origin]) -> Tuple[Codepoint, Int]` Decodes a single `Codepoint` and number of bytes read from a given UTF-8 string pointer. Safety: `s` MUST point to the first byte in a **known-valid** UTF-8 character sequence. 
This function MUST NOT be used on unvalidated input. **Args:** * ​s ([`Span`](/mojo/stdlib/memory/span/Span)): Span of UTF-8 encoded data containing at least one valid encoded codepoint. **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): The decoded codepoint `Codepoint`, as well as the number of bytes read. ### `__int__` `__int__(self) -> Int` Returns the numeric value of this scalar value as an integer. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The numeric value of this scalar value as an integer. ### `__str__` `__str__(self) -> String` Formats this `Codepoint` as a single-character string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing this single character. ### `write_to` `write_to(self, mut w: T)` Write a string representation of this `Codepoint` to the given writer. **Args:** * ​w (`T`): The object to write to. ### `is_ascii` `is_ascii(self) -> Bool` Returns True if this `Codepoint` is an ASCII character. All ASCII characters are less than or equal to codepoint value 127, and take exactly 1 byte to encode in UTF-8. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): A boolean indicating if this `Codepoint` is an ASCII character. ### `is_ascii_digit` `is_ascii_digit(self) -> Bool` Determines whether the given character is a digit \[0-9]. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the character is a digit. ### `is_ascii_upper` `is_ascii_upper(self) -> Bool` Determines whether the given character is an uppercase character. This currently only respects the default "C" locale, i.e. returns True iff the character specified is one of "ABCDEFGHIJKLMNOPQRSTUVWXYZ". **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the character is uppercase. ### `is_ascii_lower` `is_ascii_lower(self) -> Bool` Determines whether the given character is a lowercase character. This currently only respects the default "C" locale, i.e. 
returns True iff the character specified is one of "abcdefghijklmnopqrstuvwxyz". **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the character is lowercase. ### `is_ascii_printable` `is_ascii_printable(self) -> Bool` Determines whether the given character is a printable character. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the character is a printable character, otherwise False. ### `is_python_space` `is_python_space(self) -> Bool` Determines whether this character is a Python whitespace character. This corresponds to Python's [universal separators](https://docs.python.org/3/library/stdtypes.html#str.splitlines): `" \t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029"`. # Examples Check if a string contains only whitespace: ```mojo from testing import assert_false, assert_true # ASCII space characters assert_true(Codepoint.ord(" ").is_python_space()) assert_true(Codepoint.ord("\t").is_python_space()) # Unicode paragraph separator: assert_true(Codepoint.from_u32(0x2029).value().is_python_space()) # Letters are not space characters assert_false(Codepoint.ord("a").is_python_space()) ``` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this character is one of the whitespace characters listed above, otherwise False. ### `is_posix_space` `is_posix_space(self) -> Bool` Returns True if this `Codepoint` is a **space** character according to the [POSIX locale][1]. The POSIX locale is also known as the C locale. [1]: https://pubs.opengroup.org/onlinepubs/9699919799/basedefs/V1_chap07.html#tag_07_03_01 This only respects the default "C" locale, i.e. returns True only if the character specified is one of " \t\n\v\f\r". For semantics similar to Python, use `String.isspace()`. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True iff the character is one of the whitespace characters listed above. ### `to_u32` `to_u32(self) -> UInt32` Returns the numeric value of this scalar value as an unsigned 32-bit integer. 
**Returns:** [`UInt32`](/mojo/stdlib/builtin/simd/#uint32): The numeric value of this scalar value as an unsigned 32-bit integer. ### `unsafe_write_utf8` `unsafe_write_utf8[optimize_ascii: Bool = True, branchless: Bool = False](self, ptr: UnsafePointer[Byte, origin, address_space=address_space]) -> Int` Shift unicode to utf8 representation. Safety: `ptr` MUST point to at least `self.utf8_byte_length()` allocated bytes or else an out-of-bounds write will occur, which is undefined behavior. ### Unicode (represented as UInt32 BE) to UTF-8 conversion: * 1: 00000000 00000000 00000000 0aaaaaaa -> 0aaaaaaa * a * 2: 00000000 00000000 00000aaa aabbbbbb -> 110aaaaa 10bbbbbb * (a >> 6) | 0b11000000, b | 0b10000000 * 3: 00000000 00000000 aaaabbbb bbcccccc -> 1110aaaa 10bbbbbb 10cccccc * (a >> 12) | 0b11100000, (b >> 6) | 0b10000000, c | 0b10000000 * 4: 00000000 000aaabb bbbbcccc ccdddddd -> 11110aaa 10bbbbbb 10cccccc 10dddddd * (a >> 18) | 0b11110000, (b >> 12) | 0b10000000, (c >> 6) | 0b10000000, d | 0b10000000 . **Parameters:** * ​optimize\_ascii ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Optimize for languages with mostly ASCII characters. * ​branchless ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Use a branchless algorithm. **Args:** * ​ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer value to write the encoded UTF-8 bytes. Must validly point to a sufficient number of bytes (1-4) to hold the encoded data. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Returns the number of bytes written. ### `utf8_byte_length` `utf8_byte_length(self) -> Int` Returns the number of UTF-8 bytes required to encode this character. Notes: The returned value is always between 1 and 4 bytes. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Byte count of UTF-8 bytes required to encode this character.
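The 1-to-4-byte UTF-8 widths tabulated above can be observed through `utf8_byte_length`. A minimal sketch; the sample characters are illustrative:

```mojo
def main():
    print(Codepoint.ord("a").utf8_byte_length())   # 1 (ASCII)
    print(Codepoint.ord("é").utf8_byte_length())   # 2 (U+00E9)
    print(Codepoint.ord("世").utf8_byte_length())  # 3 (U+4E16)
    print(Codepoint.ord("🔥").utf8_byte_length())  # 4 (U+1F525)
```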
--- ## codepoint (Codepoint)
Unicode codepoint handling. This module provides the `Codepoint` type for representing single Unicode scalar values. A codepoint represents a single Unicode character, restricted to valid Unicode scalar values in the ranges 0 to 0xD7FF and 0xE000 to 0x10FFFF inclusive. The `Codepoint` type provides functionality for: * Converting between codepoints and UTF-8 encoded bytes. * Testing character properties like ASCII, digits, whitespace etc. * Converting between codepoints and strings. * Safe construction from integers with validation. Example: ```mojo from collections.string import Codepoint from testing import assert_true # Create a codepoint from a character var c = Codepoint.ord('A') # Check properties assert_true(c.is_ascii()) assert_true(c.is_ascii_upper()) # Convert to string var s = String(c) # "A" ``` ## Structs * [​`Codepoint`](/mojo/stdlib/collections/string/codepoint/Codepoint): A Unicode codepoint, typically a single user-recognizable character; restricted to valid Unicode scalar values.
--- ## format (Format)
String formatting utilities for Mojo. This module provides string formatting functionality similar to Python's `str.format()` method. The `format()` method (available on the [`String`](/mojo/stdlib/collections/string/string/String#format) and [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice#format) types) takes the current string as a template (or "format string"), which can contain literal text and/or replacement fields delimited by curly braces (`{}`). The replacement fields are replaced with the values of the arguments. Replacement fields can be mapped to the arguments in one of two ways: * Automatic indexing by argument position: ```mojo var s = "{} is {}".format("Mojo", "🔥") ``` * Manual indexing by argument position: ```mojo var s = "{1} is {0}".format("hot", "🔥") ``` The replacement fields can also contain the `!r` or `!s` conversion flags, to indicate whether the argument should be formatted using `repr()` or `String()`, respectively: ```mojo var s = "{!r}".format(myComplicatedObject) ``` Note that the following features from Python's `str.format()` are **not yet supported**: * Named arguments (for example `"{name} is {adjective}"`). * Accessing the attributes of an argument value (for example, `"{0.name}"`). * Accessing an indexed value from the argument (for example, `"{1[0]}"`). * Format specifiers for controlling output format (width, precision, and so on). Examples: ```mojo # Basic formatting var s1 = "Hello {0}!".format("World") # Hello World! # Multiple arguments var s2 = "{0} plus {1} equals {2}".format(1, 2, 3) # 1 plus 2 equals 3 # Conversion flags var s4 = "{!r}".format("test") # "'test'" ``` This module has no public API; its functionality is available through the [`String.format()`](/mojo/stdlib/collections/string/string/String#format) and [`StringSlice.format()`](/mojo/stdlib/collections/string/string_slice/StringSlice#format) methods.
--- ## string
The string package provides comprehensive Unicode string handling functionality for Mojo. This package implements Unicode-aware string types and operations, with UTF-8 support. It includes efficient implementations for string manipulation, formatting, and Unicode operations while maintaining memory safety and performance. Key Components: * `String`: The main string type supporting UTF-8 encoded text * `StringSlice`: Memory-efficient string view type for zero-copy operations * `Codepoint`: Unicode codepoint handling and operations * Format: String formatting and interpolation utilities Core Features: * Unicode support with UTF-8 encoding * Efficient string slicing and views * String formatting and interpolation * Memory-safe string operations * Unicode case conversion * Unicode property lookups and validation Example: ```mojo # Basic string creation and manipulation var s = "Hello, 世界" # runtime type is `String` var slice = s[0:5] # "Hello" # Unicode-aware operations for c in s.codepoints(): if c.is_ascii_lower(): print(String(c).upper()) else: print(c) # String formatting var name = "Mojo" var formatted = "Hello, {0}!".format(name) ``` Note: String stores data using UTF-8, and all operations (unless clearly noted) are intended to be fully Unicode compliant and maintain correct UTF-8 encoded data. A handful of operations are known to not be Unicode / UTF-8 compliant yet, but will be fixed as time permits. ## Modules * [​`codepoint`](/mojo/stdlib/collections/string/codepoint/): Unicode codepoint handling. * [​`format`](/mojo/stdlib/collections/string/format/): String formatting utilities for Mojo. * [​`string`](/mojo/stdlib/collections/string/string/): The core `String` type implementation for Mojo. * [​`string_slice`](/mojo/stdlib/collections/string/string_slice/): The `StringSlice` type implementation for efficient string operations.
--- ## String (String)
`struct String` Represents a mutable string. See the [`string` module](/mojo/stdlib/collections/string/string/) for more information and examples. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`Comparable`](/mojo/stdlib/builtin/comparable/Comparable), [`ConvertibleFromPython`](/mojo/stdlib/python/conversions/ConvertibleFromPython), [`ConvertibleToPython`](/mojo/stdlib/python/conversions/ConvertibleToPython), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`FloatableRaising`](/mojo/stdlib/builtin/floatable/FloatableRaising), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`PathLike`](/mojo/stdlib/os/pathlike/PathLike), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable), [`Writer`](/mojo/stdlib/io/write/Writer) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = False` ### `ASCII_LETTERS` `comptime ASCII_LETTERS = String.ASCII_LOWERCASE.__add__["abcdefghijklmnopqrstuvwxyz"](String.ASCII_UPPERCASE)` All ASCII letters (lowercase and uppercase). ### `ASCII_LOWERCASE` `comptime ASCII_LOWERCASE = "abcdefghijklmnopqrstuvwxyz"` All lowercase ASCII letters. ### `ASCII_UPPERCASE` `comptime ASCII_UPPERCASE = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"` All uppercase ASCII letters. 
### `DIGITS` `comptime DIGITS = "0123456789"` All decimal digit characters. ### `FLAG_HAS_NUL_TERMINATOR` `comptime FLAG_HAS_NUL_TERMINATOR = (1 << (Int.BITWIDTH - 3))` Flag indicating string has accessible nul terminator. ### `FLAG_IS_INLINE` `comptime FLAG_IS_INLINE = (1 << (Int.BITWIDTH - 1))` Flag indicating string uses inline (SSO) storage. ### `FLAG_IS_REF_COUNTED` `comptime FLAG_IS_REF_COUNTED = (1 << (Int.BITWIDTH - 2))` Flag indicating string uses reference-counted storage. ### `HEX_DIGITS` `comptime HEX_DIGITS = String.DIGITS.__add__["0123456789"]("abcdef").__add__["0123456789abcdef"]("ABCDEF")` All hexadecimal digit characters. ### `INLINE_CAPACITY` `comptime INLINE_CAPACITY = (((Int.BITWIDTH // 8) * 3) - 1)` Maximum bytes for inline (SSO) string storage. ### `INLINE_LENGTH_MASK` `comptime INLINE_LENGTH_MASK = (31 << String.INLINE_LENGTH_START)` Bit mask for extracting inline string length. ### `INLINE_LENGTH_START` `comptime INLINE_LENGTH_START = (Int.BITWIDTH - 8)` Bit position where inline length field starts. ### `OCT_DIGITS` `comptime OCT_DIGITS = "01234567"` All octal digit characters. ### `PRINTABLE` ``comptime PRINTABLE = String.DIGITS.__add__["0123456789"](String.ASCII_LETTERS).__add__["0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"](String.PUNCTUATION).__add__["0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"](" \t\n\r\v\f")`` All printable ASCII characters. ### `PUNCTUATION` ``comptime PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"`` All ASCII punctuation characters. ### `REF_COUNT_SIZE` `comptime REF_COUNT_SIZE = size_of[Atomic[DType.index]]()` Size of the reference count prefix for heap strings. ## Methods ### `__init__` `__init__(out self)` Construct an empty string. `__init__(out self, *, capacity: Int)` Construct an empty string with a given capacity. **Args:** * ​capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): The capacity of the string to allocate. 
`@implicit` `__init__(out self, data: StringSlice[StaticConstantOrigin])` Construct a `String` from a `StaticString` without allocating. **Args:** * ​data ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The static constant string to refer to. `@implicit` `__init__(out self, data: StringLiteral[value])` Construct a `String` from a `StringLiteral` without allocating. **Args:** * ​data ([`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral)): The static constant string to refer to. `__init__(out self, *, bytes: Span[Byte, origin])` Construct a string by copying the data. This constructor is explicit because it can involve memory allocation. **Args:** * ​bytes ([`Span`](/mojo/stdlib/memory/span/Span)): The bytes to copy. `__init__[T: Stringable](out self, value: T)` Initialize from a type conforming to `Stringable`. **Parameters:** * ​T ([`Stringable`](/mojo/stdlib/builtin/str/Stringable)): The type conforming to Stringable. **Args:** * ​value (`T`): The object to get the string representation of. `__init__[T: StringableRaising](out self, value: T)` Initialize from a type conforming to `StringableRaising`. **Parameters:** * ​T ([`StringableRaising`](/mojo/stdlib/builtin/str/StringableRaising)): The type conforming to StringableRaising. **Args:** * ​value (`T`): The object to get the string representation of. **Raises:** If there is an error when computing the string representation of the type. `__init__[*Ts: Writable](out self, *args: *Ts, *, sep: StringSlice[StaticConstantOrigin] = "", end: StringSlice[StaticConstantOrigin] = "")` Construct a string by concatenating a sequence of Writable arguments. Examples: Construct a String from several `Writable` arguments: ```mojo var string = String(1, 2.0, "three", sep=", ") print(string) # "1, 2.0, three" ``` **Parameters:** * ​\*Ts ([`Writable`](/mojo/stdlib/io/write/Writable)): Types of the provided argument sequence. **Args:** * ​\*args (`*Ts`): A sequence of Writable arguments. 
* ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The separator used between elements. * ​end ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The String to write after printing the elements. `__init__[*Ts: Writable](out self, args: VariadicPack[is_owned, origin, Writable, Ts], sep: StringSlice[StaticConstantOrigin] = "", end: StringSlice[StaticConstantOrigin] = "")` Construct a string by passing a variadic pack. Examples: ```mojo fn variadic_pack_to_string[ *Ts: Writable, ](*args: *Ts) -> String: return String(args) string = variadic_pack_to_string(1, ", ", 2.0, ", ", "three") ``` **Parameters:** * ​\*Ts ([`Writable`](/mojo/stdlib/io/write/Writable)): Types of the provided argument sequence. **Args:** * ​args ([`VariadicPack`](/mojo/stdlib/builtin/variadics/VariadicPack)): A VariadicPack of Writable arguments. * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The separator used between elements. * ​end ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The String to write after printing the elements. `__init__(out self, *, unsafe_uninit_length: Int)` Construct a String with the specified length, with uninitialized memory. This is unsafe, as it relies on the caller initializing the elements with unsafe operations, not assigning over the uninitialized data. **Args:** * ​unsafe\_uninit\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of bytes to allocate. `__init__(out self, *, unsafe_from_utf8_ptr: UnsafePointer[c_char, origin])` Creates a string from a UTF-8 encoded nul-terminated pointer. Safety: * `unsafe_from_utf8_ptr` MUST be valid UTF-8 encoded data. * `unsafe_from_utf8_ptr` MUST be null terminated. **Args:** * ​unsafe\_from\_utf8\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): An `UnsafePointer[Byte]` of null-terminated bytes encoded in UTF-8. 
`__init__(out self, *, unsafe_from_utf8_ptr: UnsafePointer[UInt8, origin])` Creates a string from a UTF-8 encoded nul-terminated pointer. Safety: * `unsafe_from_utf8_ptr` MUST be valid UTF-8 encoded data. * `unsafe_from_utf8_ptr` MUST be null terminated. **Args:** * ​unsafe\_from\_utf8\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): An `UnsafePointer[Byte]` of null-terminated bytes encoded in UTF-8. `__init__(out self, obj: PythonObject)` Construct a `String` from a PythonObject. **Args:** * ​obj ([`PythonObject`](/mojo/stdlib/python/python_object/PythonObject)): The PythonObject to convert from. **Raises:** An error if the conversion failed. ### `__copyinit__` `__copyinit__(out self, other: Self)` Copy initialize the string from another string. **Args:** * ​other (`Self`): The string to copy. ### `__moveinit__` `__moveinit__(out self, deinit other: Self)` Move initialize the string from another string. **Args:** * ​other (`Self`): The string to move. ### `__del__` `__del__(deinit self)` Destroy the string data. ### `__bool__` `__bool__(self) -> Bool` Checks if the string is not empty. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the string length is greater than zero, and False otherwise. ### `__getitem__` `__getitem__[I: Indexer, //](self, idx: I) -> StringSlice[self]` Gets the character at the specified position. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): A type that can be used as an index. **Args:** * ​idx (`I`): The index value. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A StringSlice view containing the character at the specified position. `__getitem__(self, span: ContiguousSlice) -> StringSlice[self]` Gets the sequence of characters at the specified positions. **Args:** * ​span ([`ContiguousSlice`](/mojo/stdlib/builtin/builtin_slice/ContiguousSlice)): A slice that specifies positions of the new substring. 
**Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A string slice containing the characters at the specified positions. ### `__lt__` `__lt__(self, rhs: Self) -> Bool` Compare this String to the RHS using LT comparison. **Args:** * ​rhs (`Self`): The other String to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this String is strictly less than the RHS String and False otherwise. ### `__eq__` `__eq__(self, rhs: Self) -> Bool` Checks whether two Strings have the same values. **Args:** * ​rhs (`Self`): The rhs of the operation. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Strings are equal and False otherwise. `__eq__(self, other: StringSlice[origin]) -> Bool` Checks whether two Strings have the same values. **Args:** * ​other ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The rhs of the operation. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Strings are equal and False otherwise. ### `__ne__` `__ne__(self, other: StringSlice[origin]) -> Bool` Checks whether two Strings have different values. **Args:** * ​other ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The rhs of the operation. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the Strings are not equal and False otherwise. ### `__contains__` `__contains__(self, substr: StringSlice[origin]) -> Bool` Returns True if the substring is contained within the current string. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the string contains the substring. ### `__add__` `__add__(self, other: StringSlice[origin]) -> Self` Creates a string by appending a string slice at the end. **Args:** * ​other ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string slice to append. 
**Returns:** `Self`: The new constructed string. ### `__mul__` `__mul__(self, n: Int) -> Self` Concatenates the string `n` times. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of times to concatenate the string. **Returns:** `Self`: The string concatenated `n` times. ### `__radd__` `__radd__(self, other: StringSlice[origin]) -> Self` Creates a string by prepending another string slice to the start. **Args:** * ​other ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to prepend. **Returns:** `Self`: The new constructed string. ### `__iadd__` `__iadd__(mut self, other: StringSlice[origin])` Appends another string slice to this string. **Args:** * ​other ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to append. ### `write` `static write[*Ts: Writable](*args: *Ts, *, sep: StringSlice[StaticConstantOrigin] = "", end: StringSlice[StaticConstantOrigin] = "") -> Self` Construct a string by concatenating a sequence of Writable arguments. **Parameters:** * ​\*Ts ([`Writable`](/mojo/stdlib/io/write/Writable)): Types of the provided argument sequence. **Args:** * ​\*args (`*Ts`): A sequence of Writable arguments. * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The separator used between elements. * ​end ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The String to write after printing the elements. **Returns:** `Self`: A string formed by formatting the argument sequence. `write[*Ts: Writable](mut self, *args: *Ts)` Write a sequence of Writable arguments to the provided Writer. **Parameters:** * ​\*Ts ([`Writable`](/mojo/stdlib/io/write/Writable)): Types of the provided argument sequence. **Args:** * ​\*args (`*Ts`): Sequence of arguments to write to this Writer. `write[T: Writable](mut self, value: T)` Write a single Writable argument to the provided Writer. 
**Parameters:** * ​T ([`Writable`](/mojo/stdlib/io/write/Writable)): The type of the value to write, which must implement `Writable`. **Args:** * ​value (`T`): The `Writable` argument to write. `static write[T: Writable](value: T) -> Self` Write a single Writable argument to the provided Writer. **Parameters:** * ​T ([`Writable`](/mojo/stdlib/io/write/Writable)): The type of the value to write, which must implement `Writable`. **Args:** * ​value (`T`): The `Writable` argument to write. **Returns:** `Self`: A new `String` containing the written value. ### `capacity` `capacity(self) -> Int` Get the current capacity of the `String`'s internal buffer. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of bytes that can be stored before reallocation is needed. ### `write_bytes` `write_bytes(mut self, bytes: Span[Byte, origin])` Write a byte span to this String. **Args:** * ​bytes ([`Span`](/mojo/stdlib/memory/span/Span)): The byte span to write to this String. Must NOT be null terminated. ### `append_byte` `append_byte(mut self, byte: UInt8)` Append a byte to the string. **Args:** * ​byte ([`UInt8`](/mojo/stdlib/builtin/simd/#uint8)): The byte to append. ### `__iter__` `__iter__(self) -> CodepointSliceIter[self]` Iterate over the string, returning immutable references. **Returns:** `CodepointSliceIter`: An iterator of references to the string elements. ### `__reversed__` `__reversed__(self) -> CodepointSliceIter[self, False]` Iterate backwards over the string, returning immutable references. **Returns:** `CodepointSliceIter`: A reversed iterator of references to the string elements. ### `__len__` `__len__(self) -> Int` Get the string length in bytes. This function returns the number of bytes in the underlying UTF-8 representation of the string. To get the number of Unicode codepoints in a string, use `len(str.codepoints())`. 
# Examples Query the length of a string, in bytes and Unicode codepoints: ```mojo from testing import assert_equal var s = "ನಮಸ್ಕಾರ" assert_equal(len(s), 21) assert_equal(len(s.codepoints()), 7) ``` Strings containing only ASCII characters have the same byte and Unicode codepoint length: ```mojo from testing import assert_equal var s = "abc" assert_equal(len(s), 3) assert_equal(len(s.codepoints()), 3) ``` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The string length in bytes. ### `__str__` `__str__(self) -> Self` Gets the string itself. This method ensures that you can pass a `String` to a method that takes a `Stringable` value. **Returns:** `Self`: The string itself. ### `__repr__` `__repr__(self) -> Self` Return a Mojo-compatible representation of the `String` instance. **Returns:** `Self`: A new representation of the string. ### `__fspath__` `__fspath__(self) -> Self` Return the file system path representation (just the string itself). **Returns:** `Self`: The file system path representation as a string. ### `to_python_object` `to_python_object(var self) -> PythonObject` Convert this value to a PythonObject. **Returns:** [`PythonObject`](/mojo/stdlib/python/python_object/PythonObject): A PythonObject representing the value. **Raises:** If the operation fails. ### `write_to` `write_to(self, mut writer: T)` Formats this string to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `join` `join[T: Copyable & Writable](self, elems: Span[T, origin]) -> Self` Joins string elements using the current string as a delimiter. Notes: * Defaults to writing directly to the string if the bytes fit in an inline `String`, otherwise will process it by chunks. 
* The `buffer_size` defaults to 4096 bytes to match the default page size on arm64 and x86-64, but you can increase this if you're joining a very large `List` of elements to write into the stack instead of the heap. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): The type of the elements. Must implement the `Copyable`, and `Writable` traits. **Args:** * ​elems ([`Span`](/mojo/stdlib/memory/span/Span)): The input values. **Returns:** `Self`: The joined string. ### `codepoints` `codepoints(self) -> CodepointsIter[self]` Returns an iterator over the `Codepoint`s encoded in this string slice. # Examples Print the characters in a string: ```mojo from testing import assert_equal var s = "abc" var iter = s.codepoints() assert_equal(iter.__next__(), Codepoint.ord("a")) assert_equal(iter.__next__(), Codepoint.ord("b")) assert_equal(iter.__next__(), Codepoint.ord("c")) assert_equal(iter.__has_next__(), False) ``` `codepoints()` iterates over Unicode codepoints, and supports multibyte codepoints: ```mojo from testing import assert_equal # A visual character composed of a combining sequence of 2 codepoints. var s = "á" assert_equal(s.byte_length(), 3) var iter = s.codepoints() assert_equal(iter.__next__(), Codepoint.ord("a")) # U+0301 Combining Acute Accent assert_equal(iter.__next__().to_u32(), 0x0301) assert_equal(iter.__has_next__(), False) ``` **Returns:** `CodepointsIter`: An iterator type that returns successive `Codepoint` values stored in this string slice. ### `codepoint_slices` `codepoint_slices(self) -> CodepointSliceIter[self]` Returns an iterator over single-character slices of this string. Each returned slice points to a single Unicode codepoint encoded in the underlying UTF-8 representation of this string. 
# Examples Iterate over the character slices in a string: ```mojo from testing import assert_equal, assert_true var s = "abc" var iter = s.codepoint_slices() assert_true(iter.__next__() == "a") assert_true(iter.__next__() == "b") assert_true(iter.__next__() == "c") assert_equal(iter.__has_next__(), False) ``` **Returns:** `CodepointSliceIter`: An iterator of references to the string elements. ### `unsafe_ptr` `unsafe_ptr(self) -> UnsafePointer[Byte, self]` Retrieves a pointer to the underlying memory. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The pointer to the underlying memory. ### `unsafe_ptr_mut` `unsafe_ptr_mut(mut self, var capacity: Int = 0) -> UnsafePointer[Byte, self]` Retrieves a mutable pointer to the unique underlying memory. Passing a capacity larger than the existing capacity reallocates the string to the new capacity, allowing you to write more data. **Args:** * ​capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): The new capacity of the string. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The pointer to the underlying memory. ### `as_c_string_slice` `as_c_string_slice(mut self) -> CStringSlice[origin_of((muttoimm self))]` Return a `CStringSlice` to the underlying memory of the string. **Returns:** `CStringSlice`: The `CStringSlice` of the string. ### `unsafe_cstr_ptr` `unsafe_cstr_ptr(mut self) -> UnsafePointer[c_char, origin_of((muttoimm self))]` Retrieves a C-string-compatible pointer to the underlying memory. The returned pointer is guaranteed to be NUL (null) terminated. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The pointer to the underlying memory. ### `as_bytes` `as_bytes(self) -> Span[Byte, self]` Returns a contiguous slice of the bytes owned by this string. **Returns:** [`Span`](/mojo/stdlib/memory/span/Span): A contiguous slice pointing to the bytes owned by this string. 
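As a brief sketch of how `as_bytes` exposes the UTF-8 data (relying only on the methods documented above; the `Int` conversion of the byte is assumed for the comparison):

```mojo
from testing import assert_equal

var s = String("hi")
var bytes = s.as_bytes()
# Two ASCII characters occupy two UTF-8 bytes.
assert_equal(len(bytes), 2)
assert_equal(Int(bytes[0]), ord("h"))
```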
### `as_bytes_mut` `as_bytes_mut(mut self) -> Span[Byte, self]` Returns a mutable contiguous slice of the bytes owned by this string. This name has a \_mut suffix so the as\_bytes() method doesn't have to guarantee mutability. **Returns:** [`Span`](/mojo/stdlib/memory/span/Span): A contiguous slice pointing to the bytes owned by this string. ### `as_string_slice` `as_string_slice(self) -> StringSlice[self]` Returns a string slice of the data owned by this string. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A string slice pointing to the data owned by this string. ### `as_string_slice_mut` `as_string_slice_mut(mut self) -> StringSlice[self]` Returns a mutable string slice of the data owned by this string. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A string slice pointing to the data owned by this string. ### `byte_length` `byte_length(self) -> Int` Get the string length in bytes. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of this string in bytes. ### `set_byte_length` `set_byte_length(mut self, new_len: Int)` Set the byte length of the `String`. This is an internal helper method that updates the length field. **Args:** * ​new\_len ([`Int`](/mojo/stdlib/builtin/int/Int)): The new byte length to set. ### `count` `count(self, substr: StringSlice[origin]) -> Int` Return the number of non-overlapping occurrences of substring `substr` in the string. If `substr` is empty, returns the number of empty strings between characters, which is the length of the string plus one. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to count. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of occurrences of `substr`. ### `find` `find(self, substr: StringSlice[origin], start: Int = 0) -> Int` Finds the offset of the first occurrence of `substr` starting at `start`. If not found, returns -1. 
**Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to find. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset from which to find. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The offset of `substr` relative to the beginning of the string. ### `rfind` `rfind(self, substr: StringSlice[origin], start: Int = 0) -> Int` Finds the offset of the last occurrence of `substr` starting at `start`. If not found, returns -1. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to find. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset from which to find. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The offset of `substr` relative to the beginning of the string. ### `isspace` `isspace(self) -> Bool` Determines whether every character in the given String is a Python whitespace character. This corresponds to Python's [universal separators](https://docs.python.org/3/library/stdtypes.html#str.splitlines) `" \t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029"`. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the whole String is made up of whitespace characters listed above, otherwise False. ### `split` `split(self, sep: StringSlice[origin]) -> List[StringSlice[self]]` Split the string by a separator. Examples: ```mojo # Splitting on a space _ = StringSlice("hello world").split(" ") # ["hello", "world"] # Splitting adjacent separators _ = StringSlice("hello,,world").split(",") # ["hello", "", "world"] # Splitting with starting or ending separators _ = StringSlice(",1,2,3,").split(",") # ['', '1', '2', '3', ''] # Splitting with an empty separator _ = StringSlice("123").split("") # ['', '1', '2', '3', ''] ``` **Args:** * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to split on. 
**Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: StringSlice[origin], maxsplit: Int) -> List[StringSlice[self]]` Split the string by a separator. Examples: ```mojo # Splitting with maxsplit _ = StringSlice("1,2,3").split(",", maxsplit=1) # ['1', '2,3'] # Splitting with starting or ending separators _ = StringSlice(",1,2,3,").split(",", maxsplit=1) # ['', '1,2,3,'] # Splitting with an empty separator _ = StringSlice("123").split("", maxsplit=1) # ['', '123'] ``` **Args:** * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to split on. * ​maxsplit ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of splits to perform. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: NoneType = None) -> List[StringSlice[self]]` Split the string by every whitespace separator. Examples: ```mojo # Splitting an empty string or filled with whitespaces _ = StringSlice(" ").split() # [] _ = StringSlice("").split() # [] # Splitting a string with leading, trailing, and middle whitespaces _ = StringSlice(" hello world ").split() # ["hello", "world"] # Splitting adjacent universal newlines: _ = StringSlice( "hello \t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029world" ).split() # ["hello", "world"] ``` **Args:** * ​sep ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): None. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: NoneType = None, *, maxsplit: Int) -> List[StringSlice[self]]` Split the string by every whitespace separator. Examples: ```mojo # Splitting with maxsplit _ = StringSlice("1 2 3").split(maxsplit=1) # ['1', '2 3'] ``` **Args:** * ​sep ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): None. 
* ​maxsplit ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of splits to perform. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. ### `splitlines` `splitlines(self, keepends: Bool = False) -> List[StringSlice[self]]` Split the string at line boundaries. This corresponds to Python's [universal newlines:](https://docs.python.org/3/library/stdtypes.html#str.splitlines) `"\r\n"` and `"\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029"`. **Args:** * ​keepends ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, line breaks are kept in the resulting strings. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by line boundaries. ### `replace` `replace(self, old: StringSlice[origin], new: StringSlice[origin]) -> Self` Return a copy of the string with all occurrences of substring `old` replaced by `new`. **Args:** * ​old ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to replace. * ​new ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to replace with. **Returns:** `Self`: The string where all occurrences of `old` are replaced with `new`. ### `strip` `strip(self, chars: StringSlice[origin]) -> StringSlice[self]` Return a copy of the string with leading and trailing characters removed. **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. Defaults to whitespace. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A copy of the string with no leading or trailing characters. `strip(self) -> StringSlice[self]` Return a copy of the string with leading and trailing whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. 
**Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A copy of the string with no leading or trailing whitespaces. ### `rstrip` `rstrip(self, chars: StringSlice[origin]) -> StringSlice[self]` Return a copy of the string with trailing characters removed. **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. Defaults to whitespace. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A copy of the string with no trailing characters. `rstrip(self) -> StringSlice[self]` Return a copy of the string with trailing whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A copy of the string with no trailing whitespaces. ### `lstrip` `lstrip(self, chars: StringSlice[origin]) -> StringSlice[self]` Return a copy of the string with leading characters removed. **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. Defaults to whitespace. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A copy of the string with no leading characters. `lstrip(self) -> StringSlice[self]` Return a copy of the string with leading whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A copy of the string with no leading whitespaces. ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. **Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance. ### `lower` `lower(self) -> Self` Returns a copy of the string with all cased characters converted to lowercase. 
**Returns:** `Self`: A new string where cased letters have been converted to lowercase. ### `upper` `upper(self) -> Self` Returns a copy of the string with all cased characters converted to uppercase. **Returns:** `Self`: A new string where cased letters have been converted to uppercase. ### `startswith` `startswith(self, prefix: StringSlice[origin], start: Int = 0, end: Int = -1) -> Bool` Checks if the string starts with the specified prefix between start and end positions. Returns True if found and False otherwise. **Args:** * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix to check. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start offset from which to check. * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end offset from which to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `self[start:end]` is prefixed by the input prefix. ### `endswith` `endswith(self, suffix: StringSlice[origin], start: Int = 0, end: Int = -1) -> Bool` Checks if the string ends with the specified suffix between start and end positions. Returns True if found and False otherwise. **Args:** * ​suffix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The suffix to check. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start offset from which to check. * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end offset from which to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `self[start:end]` is suffixed by the input suffix. ### `removeprefix` `removeprefix(self, prefix: StringSlice[origin], /) -> StringSlice[self]` Returns a new string with the prefix removed if it was present. 
Examples: ```mojo print(String('TestHook').removeprefix('Test')) # 'Hook' print(String('BaseTestCase').removeprefix('Test')) # 'BaseTestCase' ``` **Args:** * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix to remove from the string. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): `string[len(prefix):]` if the string starts with the prefix string, or a copy of the original string otherwise. ### `removesuffix` `removesuffix(self, suffix: StringSlice[origin], /) -> StringSlice[self]` Returns a new string with the suffix removed if it was present. Examples: ```mojo print(String('TestHook').removesuffix('Hook')) # 'Test' print(String('BaseTestCase').removesuffix('Test')) # 'BaseTestCase' ``` **Args:** * ​suffix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The suffix to remove from the string. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): `string[:-len(suffix)]` if the string ends with the suffix string, or a copy of the original string otherwise. ### `__int__` `__int__(self) -> Int` Parses the given string as a base-10 integer and returns that value. If the string cannot be parsed as an int, an error is raised. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer value that represents the string, or otherwise raises. **Raises:** If the operation fails. ### `__float__` `__float__(self) -> Float64` Parses the string as a float point number and returns that value. If the string cannot be parsed as a float, an error is raised. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): A float value that represents the string, or otherwise raises. **Raises:** If the operation fails. ### `format` `format[*Ts: Stringable & Representable](self, *args: *Ts) -> Self` Produce a formatted string using the current string as a template. 
The template, or "format string" can contain literal text and/or replacement fields delimited with curly braces (`{}`). Returns a copy of the format string with the replacement fields replaced with string representations of the `args` arguments. For more information, see the discussion in the [`format` module](/mojo/stdlib/collections/string/format/). Example: ```mojo # Manual indexing: print("{0} {1} {0}".format("Mojo", 1.125)) # Mojo 1.125 Mojo # Automatic indexing: print("{} {}".format(True, "hello world")) # True hello world ``` **Parameters:** * ​\*Ts ([`Stringable`](/mojo/stdlib/builtin/str/Stringable) & [`Representable`](/mojo/stdlib/builtin/repr/Representable)): The types of substitution values that implement `Representable` and `Stringable` (to be changed and made more flexible). **Args:** * ​\*args (`*Ts`): The substitution values. **Returns:** `Self`: The template with the given values substituted. **Raises:** If the operation fails. ### `isdigit` `isdigit(self) -> Bool` A string is a digit string if all characters in the string are digits and there is at least one character in the string. Note that this currently only works with ASCII strings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all characters are digits and it's not empty else False. ### `isupper` `isupper(self) -> Bool` Returns True if all cased characters in the string are uppercase and there is at least one cased character. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all cased characters in the string are uppercase and there is at least one cased character, False otherwise. ### `islower` `islower(self) -> Bool` Returns True if all cased characters in the string are lowercase and there is at least one cased character. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all cased characters in the string are lowercase and there is at least one cased character, False otherwise. 
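The classification methods above can be exercised directly. A minimal sketch, assuming the semantics described above (and the documented ASCII-only limitation of `isdigit()`):

```mojo
from testing import assert_true, assert_false

def main():
    # isdigit: every character is a digit and the string is non-empty.
    assert_true(String("123").isdigit())
    assert_false(String("12a").isdigit())
    assert_false(String("").isdigit())

    # isupper / islower: all cased characters share the case,
    # and at least one cased character is present.
    assert_true(String("HELLO42").isupper())
    assert_false(String("Hello").isupper())
    assert_true(String("hello42").islower())
```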
### `isprintable` `isprintable(self) -> Bool` Returns True if all characters in the string are ASCII printable. Note that this currently only works with ASCII strings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all characters are printable else False. ### `rjust` `rjust(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> Self` Returns the string right justified in a string of specified width. Pads the string on the left with the specified fill character so that the total length of the resulting string equals `width`. If the original string is already longer than or equal to `width`, returns the original string unchanged. Examples: ```mojo var s = String("hello") print(s.rjust(10)) # " hello" print(s.rjust(10, "*")) # "*****hello" print(s.rjust(3)) # "hello" (no padding) ``` **Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** `Self`: A right-justified string of length `width`, or the original string if its length is already greater than or equal to `width`. ### `ljust` `ljust(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> Self` Returns the string left justified in a string of specified width. Pads the string on the right with the specified fill character so that the total length of the resulting string equals `width`. If the original string is already longer than or equal to `width`, returns the original string unchanged. 
Examples: ```mojo var s = String("hello") print(s.ljust(10)) # "hello " print(s.ljust(10, "*")) # "hello*****" print(s.ljust(3)) # "hello" (no padding) ``` **Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** `Self`: A left-justified string of length `width`, or the original string if its length is already greater than or equal to `width`. ### `center` `center(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> Self` Returns the string center justified in a string of specified width. Pads the string on both sides with the specified fill character so that the total length of the resulting string equals `width`. If the padding needed is odd, the extra character goes on the right side. If the original string is already longer than or equal to `width`, returns the original string unchanged. Examples: ```mojo var s = String("hello") print(s.center(10)) # " hello " print(s.center(11, "*")) # "***hello***" print(s.center(3)) # "hello" (no padding) ``` **Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** `Self`: A center-justified string of length `width`, or the original string if its length is already greater than or equal to `width`. ### `resize` `resize(mut self, length: Int, fill_byte: UInt8 = 0)` Resize the string to a new length. 
Notes: If the new length is greater than the current length, the string is extended by the difference, and the new bytes are initialized to `fill_byte`. **Args:** * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The new length of the string. * ​fill\_byte ([`UInt8`](/mojo/stdlib/builtin/simd/#uint8)): The byte to fill any new space with. `resize(mut self, *, unsafe_uninit_length: Int)` Resizes the string to the given new size leaving any new data uninitialized. If the new size is smaller than the current one, elements at the end are discarded. If the new size is larger than the current one, the string is extended and the new data is left uninitialized. **Args:** * ​unsafe\_uninit\_length ([`Int`](/mojo/stdlib/builtin/int/Int)): The new size. ### `reserve` `reserve(mut self, new_capacity: Int)` Reserves the requested capacity. Notes: If the current capacity is greater or equal, this is a no-op. Otherwise, the storage is reallocated and the data is moved. **Args:** * ​new\_capacity ([`Int`](/mojo/stdlib/builtin/int/Int)): The new capacity in stored bytes.
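A short sketch of how `resize()` and `reserve()` interact, assuming the behavior described above (byte `0x21` is ASCII `!`):

```mojo
def main():
    var s = String("hi")
    s.reserve(16)      # preallocate capacity; no-op if capacity is already >= 16
    s.resize(4, 0x21)  # grow to 4 bytes; the new bytes are initialized to '!'
    print(s)           # expected: "hi!!"
    s.resize(2)        # shrink; trailing bytes are discarded
    print(s)           # expected: "hi"
```

`reserve()` only affects capacity, never length, so it is useful before a sequence of appends to avoid repeated reallocation.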
--- ## ascii
`ascii(value: StringSlice[origin]) -> String` Get the ASCII representation of the object. **Args:** * ​value ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The object to get the ASCII representation of. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the ASCII representation of the object.
--- ## atof
`atof(str_slice: StringSlice[origin]) -> Float64` Parses the given string as a floating-point number and returns that value. For example, `atof("2.25")` returns `2.25`. This function is in the prelude, so you don't need to import it. **Args:** * ​str\_slice ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A string to be parsed as a floating point. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): A floating-point value that represents the string. **Raises:** If the given string cannot be parsed as a floating-point value, for example in `atof("hi")`.
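Because `atof()` raises on malformed input, calls belong in a raising context. A minimal sketch:

```mojo
def main():
    var x = atof("2.25")
    print(x)  # 2.25

    # Malformed input raises rather than returning a sentinel value.
    try:
        _ = atof("hi")
    except e:
        print("not a float")
```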
--- ## atol
`atol(str_slice: StringSlice[origin], base: Int = 10) -> Int` Parses and returns the given string as an integer in the given base. If base is set to 0, the string is parsed as an integer literal, with the following considerations: * '0b' or '0B' prefix indicates binary (base 2) * '0o' or '0O' prefix indicates octal (base 8) * '0x' or '0X' prefix indicates hexadecimal (base 16) * Without a prefix, it's treated as decimal (base 10) This follows [Python's integer literals format](https://docs.python.org/3/reference/lexical_analysis.html#integers). This function is in the prelude, so you don't need to import it. Examples: ```text >>> atol("32") 32 >>> atol("FF", 16) 255 >>> atol("0xFF", 0) 255 >>> atol("0b1010", 0) 10 ``` **Args:** * ​str\_slice ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A string to be parsed as an integer in the given base. * ​base ([`Int`](/mojo/stdlib/builtin/int/Int)): Base used for conversion, value must be between 2 and 36, or 0. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer value that represents the string. **Raises:** If the given string cannot be parsed as an integer value or if an incorrect base is provided.
--- ## chr
`chr(c: Int) -> String` Returns a String based on the given Unicode code point. This is the inverse of the `ord()` function. This function is in the prelude, so you don't need to import it. Example: ```mojo print(chr(97), chr(8364)) # "a €" ``` **Args:** * ​c ([`Int`](/mojo/stdlib/builtin/int/Int)): An integer that represents a code point. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing a single character based on the given code point.
--- ## string (3)
The core `String` type implementation for Mojo. This module provides the primary `String` type and its fundamental operations. The `String` type is a mutable string, and is designed to handle UTF-8 encoded text efficiently while providing a safe and ergonomic interface for string manipulation. Related types: * [`StringSlice`](/mojo/stdlib/collections/string/string_slice/). A non-owning view of string data, which can be either mutable or immutable. * [`StaticString`](/mojo/stdlib/collections/string/string_slice/#comptime-values). A `comptime` type alias for an immutable constant `StringSlice`. * [`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral/). A string literal. String literals are compile-time values. For use at runtime, you usually want to wrap a `StringLiteral` in a `String` (for a mutable string) or `StaticString` (for an immutable constant string). Key Features: * Short string optimization (SSO) and lazy copying of constant string data. * O(1) copy operation. * Memory-safe string operations. * Efficient string concatenation and slicing. * String-to-number conversions ( [`atof()`](/mojo/stdlib/collections/string/string/atof), [`atol()`](/mojo/stdlib/collections/string/string/atol)). * Character code conversions ( [`chr()`](/mojo/stdlib/collections/string/string/chr), [`ord()`](/mojo/stdlib/collections/string/string/ord)). * String formatting with [`format()`](/mojo/stdlib/collections/string/string/String/#format). The `String` type has Unicode support through UTF-8 encoding. A handful of operations are known to not be Unicode / UTF-8 compliant yet, but will be fixed as time permits. This type is in the prelude, so it is automatically imported into every Mojo program. 
Example: ```mojo # String creation and basic operations var s1 = "Hello" var s2 = "World" var combined = s1 + " " + s2 # "Hello World" # String-to-number conversion var num = atof("3.14") var int_val = atol("42") # Character operations var char = chr(65) # "A" var code = ord("A") # 65 # String formatting print("Codepoint {} is {}".format(code, char)) # Codepoint 65 is A # ASCII utilities var ascii_str = ascii("Hello") # ASCII-only string ``` ## Structs * [​`String`](/mojo/stdlib/collections/string/string/String): Represents a mutable string. ## Functions * [​`ascii`](/mojo/stdlib/collections/string/string/ascii): Get the ASCII representation of the object. * [​`atof`](/mojo/stdlib/collections/string/string/atof): Parses the given string as a floating point and returns that value. * [​`atol`](/mojo/stdlib/collections/string/string/atol): Parses and returns the given string as an integer in the given base. * [​`chr`](/mojo/stdlib/collections/string/string/chr): Returns a String based on the given Unicode code point. This is the inverse of the `ord()` function. * [​`ord`](/mojo/stdlib/collections/string/string/ord): Returns an integer that represents the codepoint of a single-character string.
--- ## ord
`ord(s: StringSlice[origin]) -> Int` Returns an integer that represents the codepoint of a single-character string. Given a string containing a single `Codepoint`, return an integer representing the codepoint of that character. For example, `ord("a")` returns the integer `97`. This is the inverse of the `chr()` function. This function is in the prelude, so you don't need to import it. **Args:** * ​s ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The input string, which must contain only a single character. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer representing the code point of the given character.
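`ord()` and `chr()` invert each other, including for multibyte codepoints, per the descriptions above:

```mojo
from testing import assert_equal

def main():
    assert_equal(ord("a"), 97)
    assert_equal(ord(chr(97)), 97)     # chr() is the inverse of ord()
    assert_equal(chr(ord("€")), "€")   # multibyte UTF-8 codepoints round-trip too
```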
--- ## CodepointSliceIter
`struct CodepointSliceIter[mut: Bool, //, origin: Origin[mut], forward: Bool = True]` Iterator for `StringSlice` over substring slices containing a single Unicode codepoint. The `forward` parameter only controls the behavior of the `__next__()` method used for normal iteration. Calls to `next()` will always take an element from the front of the iterator, and calls to `next_back()` will always take an element from the end. ## Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the slice is mutable. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the underlying string data. * ​forward ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): The iteration direction. `False` is backwards. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Iterator`](/mojo/stdlib/iter/Iterator), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Element` `comptime Element = StringSlice[origin]` The element type yielded by iteration. ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = CodepointSliceIter[origin, forward]` The iterator type for this codepoint iterator. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. 
## Methods ### `__iter__` `__iter__(ref self) -> Self` Iterate over the `StringSlice` yielding individual characters. **Returns:** `Self`: An iterator over the characters in the string slice. ### `__has_next__` `__has_next__(self) -> Bool` Returns True if there are still elements in this iterator. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): A boolean indicating if there are still elements in this iterator. ### `__next__` `__next__(mut self) -> StringSlice[origin]` Get the next codepoint in the underlying string slice. This returns the next single-codepoint substring slice encoded in the underlying string, and advances the iterator state. If `forward` is set to `False`, this will return the next codepoint from the end of the string. This function will abort if this iterator has been exhausted. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): The next character in the string. ### `__len__` `__len__(self) -> Int` Returns the remaining length of this iterator in `Codepoint`s. The value returned from this method indicates the number of subsequent calls to `next()` that will return a value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Number of codepoints remaining in this iterator. ### `peek_next` `peek_next(self) -> Optional[StringSlice[origin]]` Check what the next single-codepoint slice in this iterator is, without advancing the iterator state. Repeated calls to this method will return the same value. # Examples `peek_next()` does not advance the iterator, so repeated calls will return the same value: ```mojo from collections.string import Codepoint from testing import assert_equal var input = StringSlice("123") var iter = input.codepoint_slices() assert_equal(iter.peek_next().value(), "1") assert_equal(iter.peek_next().value(), "1") assert_equal(iter.peek_next().value(), "1") # A call to `next()` returns the same value as `peek_next()` did, # but also advances the iterator. 
assert_equal(iter.next().value(), "1") # Later `peek_next()` calls will return the _new_ next character: assert_equal(iter.peek_next().value(), "2") ``` **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): The next codepoint slice in the underlying string, or None if the string is empty. ### `peek_back` `peek_back(mut self) -> Optional[StringSlice[origin]]` Check what the last single-codepoint slice in this iterator is, without advancing the iterator state. Repeated calls to this method will return the same value. # Examples `peek_back()` does not advance the iterator, so repeated calls will return the same value: ```mojo from collections.string import Codepoint from testing import assert_equal var input = StringSlice("123") var iter = input.codepoint_slices() # Repeated calls to `peek_back()` return the same value. assert_equal(iter.peek_back().value(), "3") assert_equal(iter.peek_back().value(), "3") assert_equal(iter.peek_back().value(), "3") # A call to `next_back()` returns the same value as `peek_back()` did, # but also advances the iterator. assert_equal(iter.next_back().value(), "3") # Later `peek_back()` calls will return the _new_ next character: assert_equal(iter.peek_back().value(), "2") ``` **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): The last codepoint slice in the underlying string, or None if the string is empty. ### `next` `next(mut self) -> Optional[StringSlice[origin]]` Get the next codepoint slice in the underlying string slice, or None if the iterator is empty. This returns the next single-codepoint substring encoded in the underlying string, and advances the iterator state. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): A character if the string is not empty, otherwise None. ### `next_back` `next_back(mut self) -> Optional[StringSlice[origin]]` Get the last single-codepoint slice in this iterator, or None if the iterator is empty. 
This returns the last codepoint slice in this iterator, and advances the iterator state. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): The last codepoint slice in the underlying string, or None if the string is empty.
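Putting the methods above together: `next()` consumes from the front, `next_back()` consumes from the end, and `__len__()` reports the codepoints not yet consumed. A small sketch, assuming the semantics described above:

```mojo
from testing import assert_equal

def main():
    var s = StringSlice("abc")
    var it = s.codepoint_slices()
    assert_equal(it.next().value(), "a")       # taken from the front
    assert_equal(it.next_back().value(), "c")  # taken from the end
    assert_equal(len(it), 1)                   # only "b" remains
    assert_equal(it.next().value(), "b")
```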
--- ## CodepointsIter
`struct CodepointsIter[mut: Bool, //, origin: Origin[mut]]` Iterator over the `Codepoint`s in a string slice, constructed by `StringSlice.codepoints()`. ## Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Mutability of the underlying string data. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): Origin of the underlying string data. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Iterable`](/mojo/stdlib/iter/Iterable), [`Iterator`](/mojo/stdlib/iter/Iterator), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Element` `comptime Element = Codepoint` The element type yielded by iteration. ### `IteratorType` `comptime IteratorType[iterable_mut: Bool, //, iterable_origin: Origin[iterable_mut]] = CodepointsIter[origin]` The iterator type for this codepoint iterator. #### Parameters * ​iterable\_mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the iterable is mutable. * ​iterable\_origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the iterable. ## Methods ### `__has_next__` `__has_next__(self) -> Bool` Returns True if there are still elements in this iterator. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): A boolean indicating if there are still elements in this iterator. ### `__next__` `__next__(mut self) -> Codepoint` Get the next codepoint in the underlying string slice. This returns the next `Codepoint` encoded in the underlying string, and advances the iterator state. 
This function will abort if this iterator has been exhausted. **Returns:** `Codepoint`: The next character in the string. ### `__len__` `__len__(self) -> Int` Returns the remaining length of this iterator in `Codepoint`s. The value returned from this method indicates the number of subsequent calls to `next()` that will return a value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): Number of codepoints remaining in this iterator. ### `peek_next` `peek_next(self) -> Optional[Codepoint]` Check what the next codepoint in this iterator is, without advancing the iterator state. Repeated calls to this method will return the same value. # Examples `peek_next()` does not advance the iterator, so repeated calls will return the same value: ```mojo from collections.string import Codepoint from testing import assert_equal var input = StringSlice("123") var iter = input.codepoints() assert_equal(iter.peek_next().value(), Codepoint.ord("1")) assert_equal(iter.peek_next().value(), Codepoint.ord("1")) assert_equal(iter.peek_next().value(), Codepoint.ord("1")) # A call to `next()` returns the same value as `peek_next()` did, # but also advances the iterator. assert_equal(iter.next().value(), Codepoint.ord("1")) # Later `peek_next()` calls will return the _new_ next character: assert_equal(iter.peek_next().value(), Codepoint.ord("2")) ``` **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): The next character in the underlying string, or None if the string is empty. ### `next` `next(mut self) -> Optional[Codepoint]` Get the next codepoint in the underlying string slice, or None if the iterator is empty. This returns the next `Codepoint` encoded in the underlying string, and advances the iterator state. **Returns:** [`Optional`](/mojo/stdlib/collections/optional/Optional): A character if the string is not empty, otherwise None.
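Because this iterator yields `Codepoint` values rather than bytes, multibyte characters count once. A minimal sketch ("é" occupies two UTF-8 bytes but is a single codepoint):

```mojo
def main():
    var n = 0
    for cp in StringSlice("héllo").codepoints():
        n += 1
    print(n)  # 5 codepoints, even though the string is 6 UTF-8 bytes long
```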
--- ## StringSlice
`@register_passable(trivial)` `struct StringSlice[mut: Bool, //, origin: Origin[mut]]` A non-owning view to encoded string data. This type is guaranteed to have the same ABI (size, alignment, and field layout) as the `llvm::StringRef` type. See the [`string_slice` module](/mojo/stdlib/collections/string/string_slice/) for more information and examples. Notes: The underlying string data is guaranteed to be encoded using UTF-8. ## Parameters * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the slice is mutable. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the underlying string data. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Boolable`](/mojo/stdlib/builtin/bool/Boolable), [`ConvertibleToPython`](/mojo/stdlib/python/conversions/ConvertibleToPython), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Defaultable`](/mojo/stdlib/builtin/value/Defaultable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`FloatableRaising`](/mojo/stdlib/builtin/floatable/FloatableRaising), [`Hashable`](/mojo/stdlib/hashlib/hash/Hashable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`IntableRaising`](/mojo/stdlib/builtin/int/IntableRaising), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`PathLike`](/mojo/stdlib/os/pathlike/PathLike), [`Representable`](/mojo/stdlib/builtin/repr/Representable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `Immutable` `comptime Immutable = StringSlice[origin_of((muttoimm origin._mlir_origin))]` The immutable version of the `StringSlice`. 
### `Mutable` `comptime Mutable = StringSlice[origin_of((mutcast origin._mlir_origin))]` The mutable version of the `StringSlice`. ## Methods ### `__init__` `__init__() -> Self` Create an empty / zero-length slice. `@implicit` `__init__(lit: StringLiteral[value]) -> StaticString` Construct a new `StringSlice` from a `StringLiteral`. **Args:** * ​lit ([`StringLiteral`](/mojo/stdlib/builtin/string_literal/StringLiteral)): The literal to construct this `StringSlice` from. **Returns:** `StaticString` `__init__(*, unsafe_from_utf8: Span[Byte, origin]) -> Self` Construct a new `StringSlice` from a sequence of UTF-8 encoded bytes. Safety: `unsafe_from_utf8` MUST be valid UTF-8 encoded data. **Args:** * ​unsafe\_from\_utf8 ([`Span`](/mojo/stdlib/memory/span/Span)): A `Span[Byte]` encoded in UTF-8. `__init__(*, unsafe_from_utf8_ptr: UnsafePointer[Byte, origin]) -> Self` Construct a new StringSlice from a `UnsafePointer[Byte]` pointing to null-terminated UTF-8 encoded bytes. Safety: * `unsafe_from_utf8_ptr` MUST point to data that is valid for `origin`. * `unsafe_from_utf8_ptr` MUST be valid UTF-8 encoded data. * `unsafe_from_utf8_ptr` MUST be null terminated. **Args:** * ​unsafe\_from\_utf8\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): An `UnsafePointer[Byte]` of null-terminated bytes encoded in UTF-8. `__init__(out self, *, from_utf8: Span[Byte, origin])` Construct a new `StringSlice` from a buffer containing UTF-8 encoded data. **Args:** * ​from\_utf8 ([`Span`](/mojo/stdlib/memory/span/Span)): A span of bytes containing UTF-8 encoded data. **Raises:** An exception is raised if the provided buffer byte values do not form valid UTF-8 encoded codepoints. `__init__(*, unsafe_from_utf8_ptr: UnsafePointer[c_char, origin]) -> Self` Construct a new StringSlice from a `UnsafePointer[c_char]` pointing to null-terminated UTF-8 encoded bytes. Safety: * `unsafe_from_utf8_ptr` MUST be valid UTF-8 encoded data. 
* `unsafe_from_utf8_ptr` MUST be null terminated. **Args:** * ​unsafe\_from\_utf8\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): An `UnsafePointer[c_char]` of null-terminated bytes encoded in UTF-8. `__init__(*, ptr: UnsafePointer[Byte, origin], length: Int) -> Self` Construct a `StringSlice` from a pointer to a sequence of UTF-8 encoded bytes and a length. Safety: * `ptr` MUST point to at least `length` bytes of valid UTF-8 encoded data. * `ptr` must point to data that is live for the duration of `origin`. **Args:** * ​ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): A pointer to a sequence of bytes encoded in UTF-8. * ​length ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of bytes of encoded data. `@implicit` `__init__[_origin: ImmutOrigin, //](ref [_origin] value: String) -> StringSlice[_origin]` Construct an immutable StringSlice. **Parameters:** * ​\_origin ([`ImmutOrigin`](/mojo/stdlib/builtin/type_aliases/#immutorigin)): The immutable origin. **Args:** * ​value ([`String`](/mojo/stdlib/collections/string/string/String)): The string value. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice) `__init__[_origin: MutOrigin, //](ref [_origin] value: String) -> StringSlice[_origin]` Construct a mutable StringSlice. **Parameters:** * ​\_origin ([`MutOrigin`](/mojo/stdlib/builtin/type_aliases/#mutorigin)): The mutable origin. **Args:** * ​value ([`String`](/mojo/stdlib/collections/string/string/String)): The string value. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice) ### `__bool__` `__bool__(self) -> Bool` Check if a string slice is non-empty. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if a string slice is non-empty, False otherwise. ### `__getitem__` `__getitem__(self, span: ContiguousSlice) -> Self` Gets the sequence of characters at the specified positions. 
**Args:** * ​span ([`ContiguousSlice`](/mojo/stdlib/builtin/builtin_slice/ContiguousSlice)): A slice that specifies positions of the new substring. **Returns:** `Self`: A new StringSlice containing the substring at the specified positions. `__getitem__[I: Indexer, //](self, idx: I) -> String` Gets the character at the specified position. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): A type that can be used as an index. **Args:** * ​idx (`I`): The index value. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new string containing the character at the specified position. ### `__lt__` `__lt__(self, rhs: StringSlice[origin]) -> Bool` Define whether this `StringSlice` is strictly less than the RHS, comparing the underlying bytes lexicographically. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The other `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this `StringSlice` compares lexicographically less than the RHS. `__lt__(self, rhs: String) -> Bool` Define whether this String slice is strictly less than the RHS. **Args:** * ​rhs ([`String`](/mojo/stdlib/collections/string/string/String)): The other `String` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this `StringSlice` compares lexicographically less than the RHS. ### `__le__` `__le__(self, rhs: StringSlice[origin]) -> Bool` Define whether this String slice is less than or equal to the RHS. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The other `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this String slice is less than or equal to the RHS StringSlice. `__le__(self, rhs: String) -> Bool` Define whether this String slice is less than or equal to the RHS. 
**Args:** * ​rhs ([`String`](/mojo/stdlib/collections/string/string/String)): The other String to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this String slice is less than or equal to the RHS String. ### `__eq__` `__eq__(self, rhs_same: Self) -> Bool` Verify if a `StringSlice` is equal to another `StringSlice` with the same origin. **Args:** * ​rhs\_same (`Self`): The `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): If the `StringSlice` is equal to the input in length and contents. `__eq__(self, rhs: String) -> Bool` Verify if a `StringSlice` is equal to another `String`. **Args:** * ​rhs ([`String`](/mojo/stdlib/collections/string/string/String)): The `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): If the `StringSlice` is equal to the input in length and contents. `__eq__(self, rhs: StringSlice[origin]) -> Bool` Verify if a `StringSlice` is equal to another `StringSlice`. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): If the `StringSlice` is equal to the input in length and contents. ### `__ne__` `__ne__(self, rhs_same: Self) -> Bool` Verify if a `StringSlice` is not equal to another `StringSlice` with the same origin. **Args:** * ​rhs\_same (`Self`): The `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): If the `StringSlice` is not equal to the input in length and contents. `__ne__(self, rhs: StringSlice[origin]) -> Bool` Verify if span is not equal to another `StringSlice`. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): If the `StringSlice` is not equal to the input in length and contents. 
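The equality and inequality overloads above compare by length and byte contents, so comparisons are case-sensitive. A short sketch:

```mojo
from testing import assert_true, assert_false

var s = StringSlice("mojo")
# Equality compares length and byte contents.
assert_true(s == "mojo")
# Byte-wise comparison is case-sensitive.
assert_false(s == "Mojo")
assert_true(s != "rust")
```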
### `__gt__` `__gt__(self, rhs: StringSlice[origin]) -> Bool` Define whether this String slice is strictly greater than the RHS. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The other `StringSlice` to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this String slice is strictly greater than the RHS StringSlice. `__gt__(self, rhs: String) -> Bool` Define whether this String slice is strictly greater than the RHS. **Args:** * ​rhs ([`String`](/mojo/stdlib/collections/string/string/String)): The other String to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this String slice is strictly greater than the RHS String. ### `__ge__` `__ge__(self, rhs: String) -> Bool` Define whether this String slice is greater than or equal to the RHS. **Args:** * ​rhs ([`String`](/mojo/stdlib/collections/string/string/String)): The other String to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if this String slice is greater than or equal to the RHS String. ### `__contains__` `__contains__(self, substr: StringSlice[origin]) -> Bool` Returns True if the substring is contained within the current string. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the string contains the substring. ### `__add__` `__add__(self, rhs: StringSlice[origin]) -> String` Returns a string with this value prefixed on another string. **Args:** * ​rhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The right side of the result. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The result string. ### `__mul__` `__mul__(self, n: Int) -> String` Concatenates the string `n` times. **Args:** * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of times to concatenate the string. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string concatenated `n` times. ### `__radd__` `__radd__(self, lhs: StringSlice[origin]) -> String` Returns a string with this value appended to another string. **Args:** * ​lhs ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The left side of the result. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The result string. ### `__str__` `__str__(self) -> String` Convert this StringSlice to a String. Notes: This will allocate a new string that copies the string contents from the provided string slice. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new String. ### `__repr__` `__repr__(self) -> String` Return a Mojo-compatible representation of this string slice. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): Representation of this string slice as a Mojo string literal input form syntax. ### `__len__` `__len__(self) -> Int` Get the string length in bytes. This function returns the number of bytes in the underlying UTF-8 representation of the string. To get the number of Unicode codepoints in a string, use `len(str.codepoints())`. # Examples Query the length of a string, in bytes and Unicode codepoints: ```mojo from testing import assert_equal var s = StringSlice("ನಮಸ್ಕಾರ") assert_equal(len(s), 21) assert_equal(len(s.codepoints()), 7) ``` Strings containing only ASCII characters have the same byte and Unicode codepoint length: ```mojo from testing import assert_equal var s = StringSlice("abc") assert_equal(len(s), 3) assert_equal(len(s.codepoints()), 3) ``` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The string length in bytes. ### `write_to` `write_to(self, mut writer: T)` Formats this string slice to the provided `Writer`. **Args:** * ​writer (`T`): The object to write to. ### `__hash__` `__hash__[H: Hasher](self, mut hasher: H)` Updates hasher with the underlying bytes. 
**Parameters:** * ​H ([`Hasher`](/mojo/stdlib/hashlib/hasher/Hasher)): The hasher type. **Args:** * ​hasher (`H`): The hasher instance. ### `__fspath__` `__fspath__(self) -> String` Return the file system path representation of this string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The file system path representation as a string. ### `to_python_object` `to_python_object(var self) -> PythonObject` Convert this value to a PythonObject. **Returns:** [`PythonObject`](/mojo/stdlib/python/python_object/PythonObject): A PythonObject representing the value. **Raises:** If the operation fails. ### `__iter__` `__iter__(self) -> CodepointSliceIter[origin]` Iterate over the string, returning immutable references. **Returns:** `CodepointSliceIter`: An iterator of references to the string elements. ### `__reversed__` `__reversed__(self) -> CodepointSliceIter[origin, False]` Iterate backwards over the string, returning immutable references. **Returns:** `CodepointSliceIter`: A reversed iterator of references to the string elements. ### `__int__` `__int__(self) -> Int` Parses the given string as a base-10 integer and returns that value. If the string cannot be parsed as an int, an error is raised. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer value that represents the string, or otherwise raises. **Raises:** If the operation fails. ### `__float__` `__float__(self) -> Float64` Parses the string as a float point number and returns that value. If the string cannot be parsed as a float, an error is raised. **Returns:** [`Float64`](/mojo/stdlib/builtin/simd/#float64): A float value that represents the string, or otherwise raises. **Raises:** If the operation fails. ### `__merge_with__` `__merge_with__[other_type: AnyStruct[StringSlice[origin]]](self) -> StringSlice[origin_of((mutcast origin._mlir_origin), (mutcast origin._mlir_origin))]` Returns a string slice with merged origins. 
**Parameters:** * ​other\_type (`AnyStruct`): The type of the origin to merge with. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A StringSlice merged with the other origin. ### `get_immutable` `get_immutable(self) -> StringSlice[origin].Immutable` Return an immutable version of this string slice. **Returns:** [`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): An immutable version of the same string slice. ### `replace` `replace(self, old: StringSlice[origin], new: StringSlice[origin]) -> String` Return a copy of the string with all occurrences of substring `old` replaced by `new`. **Args:** * ​old ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to replace. * ​new ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to replace with. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The string where all occurrences of `old` are replaced with `new`. ### `strip` `strip(self, chars: StringSlice[origin]) -> Self` Return a copy of the string with leading and trailing characters removed. Example: ```mojo print("himojohi".strip("hi")) # "mojo" ``` **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. **Returns:** `Self`: A copy of the string with no leading or trailing characters. `strip(self) -> Self` Return a copy of the string with leading and trailing whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. Example: ```mojo print(" mojo ".strip()) # "mojo" ``` **Returns:** `Self`: A copy of the string with no leading or trailing whitespaces. ### `rstrip` `rstrip(self, chars: StringSlice[origin]) -> Self` Return a copy of the string with trailing characters removed. 
Example: ```mojo print("mojohi".rstrip("hi")) # "mojo" ``` **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. **Returns:** `Self`: A copy of the string with no trailing characters. `rstrip(self) -> Self` Return a copy of the string with trailing whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. Example: ```mojo print("mojo ".rstrip()) # "mojo" ``` **Returns:** `Self`: A copy of the string with no trailing whitespaces. ### `lstrip` `lstrip(self, chars: StringSlice[origin]) -> Self` Return a copy of the string with leading characters removed. Example: ```mojo print("himojo".lstrip("hi")) # "mojo" ``` **Args:** * ​chars ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): A set of characters to be removed. **Returns:** `Self`: A copy of the string with no leading characters. `lstrip(self) -> Self` Return a copy of the string with leading whitespaces removed. This only takes ASCII whitespace into account: `" \t\n\v\f\r\x1c\x1d\x1e"`. Example: ```mojo print(" mojo".lstrip()) # "mojo" ``` **Returns:** `Self`: A copy of the string with no leading whitespaces. ### `codepoints` `codepoints(self) -> CodepointsIter[origin]` Returns an iterator over the `Codepoint`s encoded in this string slice. # Examples Print the characters in a string: ```mojo from testing import assert_equal var s = StringSlice("abc") var iter = s.codepoints() assert_equal(iter.__next__(), Codepoint.ord("a")) assert_equal(iter.__next__(), Codepoint.ord("b")) assert_equal(iter.__next__(), Codepoint.ord("c")) assert_equal(iter.__has_next__(), False) ``` `codepoints()` iterates over Unicode codepoints, and supports multibyte codepoints: ```mojo from testing import assert_equal # A visual character composed of a combining sequence of 2 codepoints. 
var s = StringSlice("á") assert_equal(s.byte_length(), 3) var iter = s.codepoints() assert_equal(iter.__next__(), Codepoint.ord("a")) # U+0301 Combining Acute Accent assert_equal(iter.__next__().to_u32(), 0x0301) assert_equal(iter.__has_next__(), False) ``` **Returns:** `CodepointsIter`: An iterator type that returns successive `Codepoint` values stored in this string slice. ### `codepoint_slices` `codepoint_slices(self) -> CodepointSliceIter[origin]` Iterate over the string, returning immutable references. **Returns:** `CodepointSliceIter`: An iterator of references to the string elements. ### `as_bytes` `as_bytes(self) -> Span[Byte, origin]` Get the sequence of encoded bytes of the underlying string. **Returns:** [`Span`](/mojo/stdlib/memory/span/Span): A slice containing the underlying sequence of encoded bytes. ### `unsafe_ptr` `unsafe_ptr(self) -> UnsafePointer[Byte, origin]` Gets a pointer to the first element of this string slice. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): A pointer pointing at the first element of this string slice. ### `byte_length` `byte_length(self) -> Int` Get the length of this string slice in bytes. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The length of this string slice in bytes. ### `char_length` `char_length(self) -> UInt` Returns the length in Unicode codepoints. This returns the number of `Codepoint` codepoint values encoded in the UTF-8 representation of this string. Note: To get the length in bytes, use `StringSlice.byte_length()`. 
# Examples Query the length of a string, in bytes and Unicode codepoints: ```mojo from testing import assert_equal var s = StringSlice("ನಮಸ್ಕಾರ") assert_equal(s.char_length(), 7) assert_equal(len(s), 21) ``` Strings containing only ASCII characters have the same byte and Unicode codepoint length: ```mojo from testing import assert_equal var s = StringSlice("abc") assert_equal(s.char_length(), 3) assert_equal(len(s), 3) ``` The character length of a string with visual combining characters is the length in Unicode codepoints, not grapheme clusters: ```mojo from testing import assert_equal var s = StringSlice("á") assert_equal(s.char_length(), 2) assert_equal(s.byte_length(), 3) ``` **Returns:** [`UInt`](/mojo/stdlib/builtin/uint/UInt): The length in Unicode codepoints. ### `is_codepoint_boundary` `is_codepoint_boundary(self, index: UInt) -> Bool` Returns True if `index` is the position of the first byte in a UTF-8 codepoint sequence, or is at the end of the string. A byte position is considered a codepoint boundary if a valid subslice of the string would end (noninclusive) at `index`. Positions `0` and `len(self)` are considered to be codepoint boundaries. Positions beyond the length of the string slice will return False. 
Examples: Check if particular byte positions are codepoint boundaries: ```mojo from testing import assert_equal, assert_true, assert_false var abc = StringSlice("abc") assert_equal(len(abc), 3) assert_true(abc.is_codepoint_boundary(0)) assert_true(abc.is_codepoint_boundary(1)) assert_true(abc.is_codepoint_boundary(2)) assert_true(abc.is_codepoint_boundary(3)) ``` Only the index of the first byte in a multi-byte codepoint sequence is considered a codepoint boundary: ```mojo var thumb = StringSlice("👍") assert_equal(len(thumb), 4) assert_true(thumb.is_codepoint_boundary(0)) assert_false(thumb.is_codepoint_boundary(1)) assert_false(thumb.is_codepoint_boundary(2)) assert_false(thumb.is_codepoint_boundary(3)) ``` Visualization showing which bytes are considered codepoint boundaries, within a piece of text that includes codepoints whose UTF-8 representation requires, respectively, 1, 2, 3, and 4-bytes. The codepoint boundary byte indices are indicated by a vertical arrow (↑). For example, this diagram shows that a slice of bytes formed by the half-open range starting at byte 3 and extending up to but not including byte 6 (`[3, 6)`) is a valid UTF-8 sequence. 
```text ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ a©➇𝄞 ┃ String ┣━━┳━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━┫ ┃97┃ 169 ┃ 10119 ┃ 119070 ┃ Unicode Codepoints ┣━━╋━━━┳━━━╋━━━┳━━━┳━━━╋━━━┳━━━┳━━━┳━━━┫ ┃97┃194┃169┃226┃158┃135┃240┃157┃132┃158┃ UTF-8 Bytes ┗━━┻━━━┻━━━┻━━━┻━━━┻━━━┻━━━┻━━━┻━━━┻━━━┛ 0 1 2 3 4 5 6 7 8 9 10 ↑ ↑ ↑ ↑ ↑ ``` The following program verifies the above diagram: ```mojo from testing import assert_true, assert_false var text = StringSlice("a©➇𝄞") assert_true(text.is_codepoint_boundary(0)) assert_true(text.is_codepoint_boundary(1)) assert_false(text.is_codepoint_boundary(2)) assert_true(text.is_codepoint_boundary(3)) assert_false(text.is_codepoint_boundary(4)) assert_false(text.is_codepoint_boundary(5)) assert_true(text.is_codepoint_boundary(6)) assert_false(text.is_codepoint_boundary(7)) assert_false(text.is_codepoint_boundary(8)) assert_false(text.is_codepoint_boundary(9)) assert_true(text.is_codepoint_boundary(10)) ``` **Args:** * ​index ([`UInt`](/mojo/stdlib/builtin/uint/UInt)): An index into the underlying byte representation of the string. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): A boolean indicating if `index` gives the position of the first byte in a UTF-8 codepoint sequence, or is at the end of the string. ### `startswith` `startswith(self, prefix: StringSlice[origin], start: Int = 0, end: Int = -1) -> Bool` Verify if the `StringSlice` starts with the specified prefix between start and end positions. The `start` and `end` positions must be offsets given in bytes, and must be codepoint boundaries. **Args:** * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix to check. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start offset in bytes from which to check. * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end offset in bytes from which to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `self[start:end]` is prefixed by the input prefix. 
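A short sketch of `startswith`, including the byte-offset arguments described above (the offsets here fall on codepoint boundaries, as required):

```mojo
from testing import assert_true, assert_false

var s = StringSlice("hello world")
assert_true(s.startswith("hello"))
assert_false(s.startswith("world"))
# start/end are byte offsets; byte 6 is where "world" begins.
assert_true(s.startswith("world", start=6))
```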
### `endswith` `endswith(self, suffix: StringSlice[origin], start: Int = 0, end: Int = -1) -> Bool` Verify if the `StringSlice` ends with the specified suffix between start and end positions. The `start` and `end` positions must be offsets given in bytes, and must be codepoint boundaries. **Args:** * ​suffix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The suffix to check. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The start offset in bytes from which to check. * ​end ([`Int`](/mojo/stdlib/builtin/int/Int)): The end offset in bytes from which to check. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the `self[start:end]` is suffixed by the input suffix. ### `removeprefix` `removeprefix(self, prefix: StringSlice[origin], /) -> Self` Returns a new string with the prefix removed if it was present. Examples: ```mojo print(StringSlice('TestHook').removeprefix('Test')) # 'Hook' print(StringSlice('BaseTestCase').removeprefix('Test')) # 'BaseTestCase' ``` **Args:** * ​prefix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The prefix to remove from the string. **Returns:** `Self`: `string[len(prefix):]` if the string starts with the prefix string, or a copy of the original string otherwise. ### `removesuffix` `removesuffix(self, suffix: StringSlice[origin], /) -> Self` Returns a new string with the suffix removed if it was present. Examples: ```mojo print(StringSlice('TestHook').removesuffix('Hook')) # 'Test' print(StringSlice('BaseTestCase').removesuffix('Test')) # 'BaseTestCase' ``` **Args:** * ​suffix ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The suffix to remove from the string. **Returns:** `Self`: `string[:-len(suffix)]` if the string ends with the suffix string, or a copy of the original string otherwise. ### `format` `format[*Ts: Stringable & Representable](self, *args: *Ts) -> String` Produce a formatted string using the current string as a template. 
The template, or "format string" can contain literal text and/or replacement fields delimited with curly braces (`{}`). Returns a copy of the format string with the replacement fields replaced with string representations of the `args` arguments. For more information, see the discussion in the [`format` module](/mojo/stdlib/collections/string/format/). Examples: ```mojo # Manual indexing: print(StringSlice("{0} {1} {0}").format("Mojo", 1.125)) # Mojo 1.125 Mojo # Automatic indexing: print(StringSlice("{} {}").format(True, "hello world")) # True hello world ``` **Parameters:** * ​\*Ts ([`Stringable`](/mojo/stdlib/builtin/str/Stringable) & [`Representable`](/mojo/stdlib/builtin/repr/Representable)): The types of substitution values that implement `Representable` and `Stringable` (to be changed and made more flexible). **Args:** * ​\*args (`*Ts`): The substitution values. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The template with the given values substituted. **Raises:** If the operation fails. ### `find` `find(self, substr: StringSlice[origin], start: Int = 0) -> Int` Finds the offset in bytes of the first occurrence of `substr` starting at `start`. If not found, returns `-1`. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to find. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset in bytes from which to find. Must be a codepoint boundary. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The offset in bytes of `substr` relative to the beginning of the string. ### `rfind` `rfind(self, substr: StringSlice[origin], start: Int = 0) -> Int` Finds the offset in bytes of the last occurrence of `substr` starting at `start`. If not found, returns `-1`. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to find. * ​start ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset in bytes from which to find. 
Must be a valid codepoint boundary. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The offset in bytes of `substr` relative to the beginning of the string. ### `isspace` `isspace[single_character: Bool = False](self) -> Bool` Determines whether every character in the given StringSlice is a python whitespace String. This corresponds to Python's [universal separators](https://docs.python.org/3/library/stdtypes.html#str.splitlines): `" \t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029"`. Example: Check if a string contains only whitespace: ```mojo # An empty string is not considered to contain only whitespace chars: assert_false(StringSlice("").isspace()) # ASCII space characters assert_true(StringSlice(" ").isspace()) assert_true(StringSlice(" ").isspace()) # Contains non-space characters assert_false(StringSlice(" abc ").isspace()) ``` **Parameters:** * ​single\_character ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to evaluate the `StringSlice` as a single unicode character (avoids overhead when already iterating). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the whole StringSlice is made up of whitespace characters listed above, otherwise False. ### `split` `split(self, sep: StringSlice[origin]) -> List[StringSlice[origin].Immutable]` Split the string by a separator. Examples: ```mojo # Splitting a space _ = StringSlice("hello world").split(" ") # ["hello", "world"] # Splitting adjacent separators _ = StringSlice("hello,,world").split(",") # ["hello", "", "world"] # Splitting with starting or ending separators _ = StringSlice(",1,2,3,").split(",") # ['', '1', '2', '3', ''] # Splitting with an empty separator _ = StringSlice("123").split("") # ['', '1', '2', '3', ''] ``` **Args:** * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to split on. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. 
`split(self, sep: StringSlice[origin], maxsplit: Int) -> List[StringSlice[origin].Immutable]` Split the string by a separator. Examples: ```mojo # Splitting with maxsplit _ = StringSlice("1,2,3").split(",", maxsplit=1) # ['1', '2,3'] # Splitting with starting or ending separators _ = StringSlice(",1,2,3,").split(",", maxsplit=1) # ['', '1,2,3,'] # Splitting with an empty separator _ = StringSlice("123").split("", maxsplit=1) # ['', '123'] ``` **Args:** * ​sep ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string to split on. * ​maxsplit ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum amount of items to split from String. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: NoneType = None) -> List[StringSlice[origin].Immutable]` Split the string by every Whitespace separator. Examples: ```mojo # Splitting an empty string or filled with whitespaces _ = StringSlice(" ").split() # [] _ = StringSlice("").split() # [] # Splitting a string with leading, trailing, and middle whitespaces _ = StringSlice(" hello world ").split() # ["hello", "world"] # Splitting adjacent universal newlines: _ = StringSlice( "hello \t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029world" ).split() # ["hello", "world"] ``` **Args:** * ​sep ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): None. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. `split(self, sep: NoneType = None, *, maxsplit: Int) -> List[StringSlice[origin].Immutable]` Split the string by every Whitespace separator. Examples: ```mojo # Splitting with maxsplit _ = StringSlice("1 2 3").split(maxsplit=1) # ['1', '2 3'] ``` **Args:** * ​sep ([`NoneType`](/mojo/stdlib/builtin/none/NoneType)): None. * ​maxsplit ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum amount of items to split from String. 
**Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by the separator. ### `isnewline` `isnewline[single_character: Bool = False](self) -> Bool` Determines whether every character in the given StringSlice is a python newline character. This corresponds to Python's [universal newlines:](https://docs.python.org/3/library/stdtypes.html#str.splitlines) `"\r\n"` and `"\t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029"`. **Parameters:** * ​single\_character ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to evaluate the `StringSlice` as a single unicode character (avoids overhead when already iterating). **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the whole StringSlice is made up of newline characters listed above, otherwise False. ### `splitlines` `splitlines(self, keepends: Bool = False) -> List[StringSlice[origin].Immutable]` Split the string at line boundaries. This corresponds to Python's [universal newlines:](https://docs.python.org/3/library/stdtypes.html#str.splitlines) `"\r\n"` and `"\t\n\v\f\r\x1c\x1d\x1e\x85\u2028\u2029"`. **Args:** * ​keepends ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, line breaks are kept in the resulting strings. **Returns:** [`List`](/mojo/stdlib/collections/list/List): A List of Strings containing the input split by line boundaries. ### `count` `count(self, substr: StringSlice[origin]) -> Int` Return the number of non-overlapping occurrences of substring `substr` in the string. If `substr` is empty, returns the number of empty strings between characters, which is the length of the string plus one. **Args:** * ​substr ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The substring to count. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of occurrences of `substr`. 
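A short sketch of `count`, including the documented behavior for an empty substring:

```mojo
from testing import assert_equal

var s = StringSlice("banana")
# Only non-overlapping occurrences are counted.
assert_equal(s.count("an"), 2)
assert_equal(s.count("na"), 2)
# An empty substring counts the gaps between characters: length + 1.
assert_equal(s.count(""), 7)
```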
### `is_ascii_digit` `is_ascii_digit(self) -> Bool` A string is a digit string if all characters in the string are digits and there is at least one character in the string. Note that this currently only works with ASCII strings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all characters are digits and it's not empty else False. ### `isupper` `isupper(self) -> Bool` Returns True if all cased characters in the string are uppercase and there is at least one cased character. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all cased characters in the string are uppercase and there is at least one cased character, False otherwise. ### `islower` `islower(self) -> Bool` Returns True if all cased characters in the string are lowercase and there is at least one cased character. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all cased characters in the string are lowercase and there is at least one cased character, False otherwise. ### `lower` `lower(self) -> String` Returns a copy of the string with all cased characters converted to lowercase. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new string where cased letters have been converted to lowercase. ### `upper` `upper(self) -> String` Returns a copy of the string with all cased characters converted to uppercase. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A new string where cased letters have been converted to uppercase. ### `is_ascii_printable` `is_ascii_printable(self) -> Bool` Returns True if all characters in the string are ASCII printable. Note that this currently only works with ASCII strings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all characters are printable else False. ### `rjust` `rjust(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> String` Returns the string slice right justified in a string of specified width. 
Pads the string slice on the left with the specified fill character so that the total length of the resulting string equals `width`. If the original string slice is already longer than or equal to `width`, returns the string slice unchanged (as a `String`). Examples: ```mojo var s = StringSlice("hello") print(s.rjust(10)) # " hello" print(s.rjust(10, "*")) # "*****hello" print(s.rjust(3)) # "hello" (no padding) ``` **Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A right-justified string of length `width`, or the original string slice (as a `String`) if its length is already greater than or equal to `width`. ### `ljust` `ljust(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> String` Returns the string slice left justified in a string of specified width. Pads the string slice on the right with the specified fill character so that the total length of the resulting string equals `width`. If the original string slice is already longer than or equal to `width`, returns the string slice unchanged (as a `String`). Examples: ```mojo var s = StringSlice("hello") print(s.ljust(10)) # "hello " print(s.ljust(10, "*")) # "hello*****" print(s.ljust(3)) # "hello" (no padding) ``` **Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A left-justified string of length `width`, or the original string slice (as a `String`) if its length is already greater than or equal to `width`. ### `center` `center(self, width: Int, fillchar: StringSlice[StaticConstantOrigin] = " ") -> String` Returns the string slice center justified in a string of specified width. Pads the string slice on both sides with the specified fill character so that the total length of the resulting string equals `width`. If the padding needed is odd, the extra character goes on the right side. If the original string slice is already longer than or equal to `width`, returns the string slice unchanged (as a `String`). Examples: ```mojo var s = StringSlice("hello") print(s.center(10)) # " hello " print(s.center(11, "*")) # "***hello***" print(s.center(3)) # "hello" (no padding) ``` **Args:** * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The total width (in bytes) of the resulting string. This is not the amount of padding, but the final length of the returned string. * ​fillchar ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The padding character to use (defaults to space). Must be a single-byte character. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A center-justified string of length `width`, or the original string slice (as a `String`) if its length is already greater than or equal to `width`. ### `join` `join[T: Copyable & Writable, //](self, elems: Span[T, origin]) -> String` Joins string elements using the current string as a delimiter. Notes: * Defaults to writing directly to the string if the bytes fit in an inline `String`, otherwise will process it by chunks. **Parameters:** * ​T ([`Copyable`](/mojo/stdlib/builtin/value/Copyable) & [`Writable`](/mojo/stdlib/io/write/Writable)): The type of the elements, must implement the `Copyable`, and `Writable` traits. 
**Args:** * ​elems ([`Span`](/mojo/stdlib/memory/span/Span)): The input values. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The joined string.
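A minimal sketch of `join` in use; the element values are illustrative, and it assumes a `List[String]` coerces to the `Span` argument:

```mojo
fn main():
    var sep = StringSlice(", ")
    var colors = [String("red"), String("green"), String("blue")]
    # Join the elements with ", " as the delimiter.
    print(sep.join(colors))  # red, green, blue
```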
--- ## get_static_string
`get_static_string[string: StringSlice[StaticConstantOrigin], *extra: StringSlice[StaticConstantOrigin]]() -> StaticString` Form a StaticString from compile-time StringSlice values. This guarantees that the returned string is compile-time constant in static memory. It also guarantees that there is a 'nul' zero byte at the end, which is not included in the returned range. **Parameters:** * ​string ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The first StringSlice value. * ​\*extra ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Additional StringSlice values to concatenate. **Returns:** `StaticString`: The string value as a StaticString.
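A hedged sketch of forming a `StaticString` from compile-time pieces; the literal values and the import path are taken from this module's location:

```mojo
from collections.string.string_slice import get_static_string

fn main():
    # Concatenate compile-time string slices into one nul-terminated
    # static string in constant memory.
    comptime greeting = get_static_string["Hello, ", "world"]()
    print(greeting)  # Hello, world
```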
--- ## string_slice
The `StringSlice` type implementation for efficient string operations. This module provides the `StringSlice` type, which is a lightweight view into string data that enables zero-copy string operations. `StringSlice` is designed for high-performance string manipulation while maintaining memory safety and UTF-8 awareness. The `StringSlice` type is particularly useful for: * High-performance string operations without copying. * Efficient string parsing and tokenization. `StaticString` is an alias for an immutable constant `StringSlice`. `StringSlice` and `StaticString` are in the prelude, so they are automatically imported into every Mojo program. Example: ```mojo # Create a string slice var text = StringSlice("Hello, 世界") # Zero-copy slicing var hello = text[0:5] # Hello # Unicode-aware operations var world = text[7:13] # "世界" # String comparison if text.startswith("Hello"): print("Found greeting") # String formatting var format_string = StaticString("{}: {}") print(format_string.format("bats", 6)) # bats: 6 ``` ## `comptime` values ### `StaticString` `comptime StaticString = StringSlice[StaticConstantOrigin]` An immutable static string slice. ## Structs * [​`CodepointsIter`](/mojo/stdlib/collections/string/string_slice/CodepointsIter): Iterator over the `Codepoint`s in a string slice, constructed by `StringSlice.codepoints()`. * [​`CodepointSliceIter`](/mojo/stdlib/collections/string/string_slice/CodepointSliceIter): Iterator for `StringSlice` over substring slices containing a single Unicode codepoint. * [​`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice): A non-owning view to encoded string data. ## Functions * [​`get_static_string`](/mojo/stdlib/collections/string/string_slice/get_static_string): Form a StaticString from compile-time StringSlice values. This guarantees that the returned string is compile-time constant in static memory. 
It also guarantees that there is a 'nul' zero byte at the end, which is not included in the returned range.
--- ## CompiledFunctionInfo
`@register_passable(trivial)` `struct CompiledFunctionInfo[func_type: AnyTrivialRegType, func: func_type, target: __mlir_type.`!kgen.target`]` Contains compilation information and results for a function. Stores assembly/IR code, function metadata, and error information from compiling a function. ## Parameters * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the function being compiled. * ​func (`func_type`): The function being compiled. * ​target (`__mlir_type.`!kgen.target\`\`): The target architecture to compile for. ## Fields * ​asm (`StaticString`): Generated assembly/IR code from the compilation process. * ​function\_name (`StaticString`): Mangled name of the compiled function, used for symbol resolution. * ​module\_name (`StaticString`): Name of the module containing the compiled function. * ​num\_captures (`Int`): Number of variables captured by the function closure. * ​capture\_sizes (`UnsafePointer[UInt64, origin_of()]`): Pointer to the sizes of the variables captured by the function closure. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `populate` `comptime populate = rebind[fn(UnsafePointer[NoneType, MutAnyOrigin]) capturing -> None](#kgen.compile_offload_closure : !kgen.param>)` Function pointer to populate captured variables in the function closure. 
## Methods ### `__contains__` `__contains__(self, content: String) -> Bool` Checks if content exists in the assembly/IR. **Args:** * ​content ([`String`](/mojo/stdlib/collections/string/string/String)): String to search for. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if content is found, False otherwise. ### `write_to` `write_to(self, mut writer: T)` Writes the assembly/IR to a writer. **Args:** * ​writer (`T`): Writer object to write the assembly to. ### `__str__` `__str__(self) -> String` Converts the assembly/IR to a string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The assembly/IR as a string. ### `write_text` `write_text[path_like: PathLike](self, path: path_like)` Writes the assembly/IR to a file. **Parameters:** * ​path\_like ([`PathLike`](/mojo/stdlib/os/pathlike/PathLike)): Type that implements the `PathLike` interface for file path representation. **Args:** * ​path (`path_like`): Path to write the file to. **Raises:** If file writing operations fail.
--- ## compile_info
`compile_info[func_type: AnyTrivialRegType, //, func: func_type, /, *, emission_kind: StringSlice[StaticConstantOrigin] = "asm", target: __mlir_type.`!kgen.target` = _current_target(), compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[target]()]() -> CompiledFunctionInfo[func_type, func, target]` Compiles a function and returns detailed compilation information. This function takes a Mojo function and compiles it, providing access to the generated assembly code, linkage information, and other compilation artifacts. It can be used for inspection, debugging, and low-level optimization. Example: ```mojo from compile import compile_info fn my_func(x: Int) -> Int: return x info = compile_info[my_func]() print(info) # Print assembly ``` Note: The compilation is always performed, even if the function is not used. For performance-critical code, consider caching the compilation results. **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the function to compile. Must be a trivially-copyable register type. * ​func (`func_type`): The function to compile. Must match the specified func\_type. * ​emission\_kind ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The desired output format. Valid options are: * "asm": Assembly code (default). * "llvm": Unoptimized LLVM IR. * "llvm-opt": Optimized LLVM IR. * "object": Object code. * ​target (`__mlir_type.`!kgen.target\`\`): The target architecture to compile for. Defaults to current architecture. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Additional compiler flags and options as a string. 
**Returns:** `CompiledFunctionInfo`: A `CompiledFunctionInfo` struct containing: * asm: The generated code in the requested format * function\_name: The mangled function name, used for symbol resolution * module\_name: The name of the module containing the compiled function * num\_captures: Number of captured variables * capture\_sizes: Pointer to the sizes of the captured variables
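Building on the example above, a hedged sketch of requesting LLVM IR instead of the default assembly output; `add_one` is an illustrative function:

```mojo
from compile import compile_info

fn add_one(x: Int) -> Int:
    return x + 1

fn main():
    # Request optimized LLVM IR rather than the default "asm" output.
    info = compile_info[add_one, emission_kind="llvm-opt"]()
    print(info.function_name)  # the mangled symbol name
    print(info)                # the generated IR itself
```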
--- ## compile
Provides utilities for compiling and inspecting Mojo code. This module contains functionality for compiling Mojo functions and examining their assembly, LLVM IR, or object code output. It is particularly useful for kernel engineers who want to inspect the low-level implementation details of specific functions without dealing with entire files or manual invocation of compilation tools. Key features: * Compile individual functions to assembly, LLVM IR, or object code * Get linkage names and module information * Inspect number of captures and other function metadata * Write compilation output to files * Control compilation options and targets Example: ```mojo from compile import compile_info fn my_func(x: Int) -> Int: return x # Get assembly for the function info = compile_info[my_func]() print(info) ``` ## Structs * [​`CompiledFunctionInfo`](/mojo/stdlib/compile/compile/CompiledFunctionInfo): Contains compilation information and results for a function. ## Functions * [​`compile_info`](/mojo/stdlib/compile/compile/compile_info): Compiles a function and returns detailed compilation information.
--- ## compile (Compile)
Provides utilities for compiling and inspecting Mojo code at runtime. This module exposes functionality for compiling individual Mojo functions and examining their low-level implementation details. It is particularly useful for: * Inspecting assembly, LLVM IR, or object code output * Getting linkage names and module information * Examining function metadata like captures * Writing compilation output to files * Controlling compilation options and targets Example: ```mojo from compile import compile_info fn my_func(): print("Hello") # Get assembly for the function info = compile_info[my_func]() print(info.asm) ``` ## Modules * [​`compile`](/mojo/stdlib/compile/compile/): Provides utilities for compiling and inspecting Mojo code. * [​`reflection`](/mojo/stdlib/compile/reflection/): Provides compile-time reflection utilities for retrieving function and type names.
--- ## get_function_name
`get_function_name[func_type: AnyType, //, func: func_type]() -> StaticString` Returns `func`'s name as declared in the source code. The returned name does not include any information about the function's parameters, arguments, or return type, just the name as declared in the source code. **Parameters:** * ​func\_type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Type of func. * ​func (`func_type`): A Mojo function. **Returns:** `StaticString`: The function's name as declared in the source code.
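A brief sketch of retrieving a function's declared name at compile time; `my_kernel` is an illustrative function:

```mojo
from compile.reflection import get_function_name

fn my_kernel(x: Int) -> Int:
    return x

fn main():
    # Prints only the declared name, with no parameter or return-type info.
    print(get_function_name[my_kernel]())  # my_kernel
```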
--- ## get_linkage_name
`get_linkage_name[func_type: AnyType, //, func: func_type, *, target: __mlir_type.`!kgen.target` = _current_target()]() -> StaticString` Returns `func`'s symbol name. **Parameters:** * ​func\_type ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Type of func. * ​func (`func_type`): A Mojo function. * ​target (`__mlir_type.`!kgen.target\`\`): The compilation target; defaults to the current target. **Returns:** `StaticString`: Symbol name.
--- ## get_type_name
`get_type_name[type_type: AnyTrivialRegType, //, type: type_type, *, qualified_builtins: Bool = False]() -> StaticString` Returns the struct name of the given type parameter. **Parameters:** * ​type\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of type. * ​type (`type_type`): A Mojo type. * ​qualified\_builtins ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to print fully qualified builtin type names (e.g. `stdlib.builtin.int.Int`) or shorten them (e.g. `Int`). **Returns:** `StaticString`: Type name.
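A minimal sketch showing the effect of `qualified_builtins` on a builtin type:

```mojo
from compile.reflection import get_type_name

fn main():
    # Shortened builtin names by default...
    print(get_type_name[Int]())  # Int
    # ...or fully qualified on request.
    print(get_type_name[Int, qualified_builtins=True]())  # stdlib.builtin.int.Int
```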
--- ## reflection
## Functions * [​`get_function_name`](/mojo/stdlib/compile/reflection/get_function_name): Returns `func`'s name as declared in the source code. * [​`get_linkage_name`](/mojo/stdlib/compile/reflection/get_linkage_name): Returns `func`'s symbol name. * [​`get_type_name`](/mojo/stdlib/compile/reflection/get_type_name): Returns the struct name of the given type parameter.
--- ## ComplexSIMD
`@register_passable(trivial)` `struct ComplexSIMD[dtype: DType, size: Int]` Represents a complex SIMD value. The class provides basic methods for manipulating complex values. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): DType of the value. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): SIMD width of the value. ## Fields * ​re (`ComplexSIMD[dtype, size].element_type`): The real part of the complex SIMD value. * ​im (`ComplexSIMD[dtype, size].element_type`): The imaginary part of the complex SIMD value. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable), [`_Expable`](/mojo/stdlib/math/math/_Expable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `element_type` `comptime element_type = SIMD[dtype, size]` The SIMD type used for real and imaginary parts. ### `type` `comptime type = dtype` The data type of the complex components. ## Methods ### `__init__` `__init__(re: SIMD[dtype, size], im: SIMD[dtype, size] = 0) -> Self` Initializes a complex SIMD value. **Args:** * ​re ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The real part of the complex value. * ​im ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The imaginary part of the complex value. `__init__(*, from_interleaved: SIMD[dtype, (2 * size)]) -> Self` Initializes a complex SIMD value. 
**Args:** * ​from\_interleaved ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): An interleaved vector of complex values e.g. `[0, 1, 1, 0]` where the pattern is `[re0, im0, re1, im1]`. `__init__(*, from_deinterleaved: SIMD[dtype, (2 * size)]) -> Self` Initializes a complex SIMD value. **Args:** * ​from\_deinterleaved ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): A deinterleaved vector of complex values e.g. `[0, 1, 1, 0]` where the pattern is `[re0, re1, im0, im1]`. ### `__neg__` `__neg__(self) -> Self` Negates the complex value. **Returns:** `Self`: The negative of the complex value. ### `__eq__` `__eq__(self, rhs: Self) -> Bool` Compares two ComplexSIMD for equality. **Args:** * ​rhs (`Self`): The ComplexSIMD to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if all elements of the ComplexSIMD are equal, False otherwise. ### `__add__` `__add__(self, rhs: Self) -> Self` Adds two complex values. **Args:** * ​rhs (`Self`): Complex value to add. **Returns:** `Self`: A sum of this and RHS complex values. ### `__sub__` `__sub__(self, rhs: Self) -> Self` Subtracts two complex values. **Args:** * ​rhs (`Self`): Complex value to subtract. **Returns:** `Self`: A difference of this and RHS complex values. ### `__mul__` `__mul__(self, rhs: Self) -> Self` Multiplies two complex values. **Args:** * ​rhs (`Self`): Complex value to multiply with. **Returns:** `Self`: A product of this and RHS complex values. `__mul__(self, rhs: Scalar[dtype]) -> Self` Multiplies a complex value by a scalar. **Args:** * ​rhs ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Scalar value to multiply with. **Returns:** `Self`: A product of self and rhs. ### `__truediv__` `__truediv__(self, rhs: Self) -> Self` Divides two complex values. **Args:** * ​rhs (`Self`): Complex value to divide by. **Returns:** `Self`: A quotient of this and RHS complex values. ### `__rmul__` `__rmul__(self, lhs: Scalar[dtype]) -> Self` Multiplies a complex value by a scalar. 
**Args:** * ​lhs ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Scalar value to multiply with. **Returns:** `Self`: A product of self and lhs. ### `__imul__` `__imul__(mut self, rhs: Self)` Multiplies two complex values in place. **Args:** * ​rhs (`Self`): Complex value to multiply with. `__imul__(mut self, rhs: Scalar[dtype])` Multiplies a complex value by a scalar in place. **Args:** * ​rhs ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Scalar value to multiply with. ### `__str__` `__str__(self) -> String` Get the complex as a string. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string representation. ### `write_to` `write_to(self, mut writer: T)` Formats this complex value to the provided Writer. **Args:** * ​writer (`T`): The object to write to. ### `__abs__` `__abs__(self) -> SIMD[dtype, size]` Returns the magnitude of the complex value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Value of `sqrt(re*re + im*im)`. ### `conj` `conj(self) -> Self` Returns the complex conjugate of self. **Returns:** `Self`: The complex conjugate of self. ### `norm` `norm(self) -> SIMD[dtype, size]` Returns the magnitude of the complex value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Value of `sqrt(re*re + im*im)`. ### `squared_norm` `squared_norm(self) -> SIMD[dtype, size]` Returns the squared magnitude of the complex value. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Value of `re*re + im*im`. ### `fma` `fma(self, b: Self, c: Self) -> Self` Computes an FMA operation. Computes a fused multiply-add with two other complex values: `result = self * b + c` **Args:** * ​b (`Self`): Multiplier complex value. * ​c (`Self`): Complex value to add. **Returns:** `Self`: Computed `Self * B + C` complex value. ### `squared_add` `squared_add(self, c: Self) -> Self` Computes a square-add operation: `Self * Self + C`. **Args:** * ​c (`Self`): Complex value to add. **Returns:** `Self`: Computed `Self * Self + C` complex value. 
### `__exp__` `__exp__(self) -> Self` Computes the exponential of the complex value. **Returns:** `Self`: The exponential of the complex value.
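Pulling the methods above together, a brief sketch of scalar complex arithmetic; the operand values are illustrative:

```mojo
from complex import ComplexFloat64

fn main():
    var a = ComplexFloat64(3, 4)
    var b = ComplexFloat64(1, -2)
    print(a + b)        # elementwise complex sum
    print(a * b)        # complex product
    print(a.norm())     # 5.0: sqrt(3*3 + 4*4)
    print(a.conj())     # complex conjugate of a
    print(a.fma(b, b))  # a * b + b in one fused step
```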
--- ## abs (Complex)
`abs(x: ComplexSIMD[dtype, size]) -> SIMD[dtype, size]` Performs elementwise abs (norm) on each element of the complex value. **Args:** * ​x ([`ComplexSIMD`](/mojo/stdlib/complex/complex/ComplexSIMD)): The complex vector to perform absolute value on. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The elementwise abs of x.
--- ## complex
Implements the Complex dtype. You can import these APIs from the `complex` package. For example: ```mojo from complex import ComplexSIMD ``` ## `comptime` values ### `ComplexFloat32` `comptime ComplexFloat32 = ComplexSIMD[DType.float32, 1]` A complex number with 32-bit floating point components. ### `ComplexFloat64` `comptime ComplexFloat64 = ComplexSIMD[DType.float64, 1]` A complex number with 64-bit floating point components. ### `ComplexScalar` `comptime ComplexScalar = ComplexSIMD[?, 1]` Represents a scalar complex value. ## Structs * [​`ComplexSIMD`](/mojo/stdlib/complex/complex/ComplexSIMD): Represents a complex SIMD value. ## Functions * [​`abs`](/mojo/stdlib/complex/complex/abs): Performs elementwise abs (norm) on each element of the complex value.
--- ## complex (Complex)
Provides types and functions for working with complex numbers. ## Modules * [​`complex`](/mojo/stdlib/complex/complex/): Implements the Complex dtype.
--- ## doc_private
`doc_private()` Indicate that the decorated declaration is private from the viewpoint of documentation generation. This decorator allows for hiding the documentation for a declaration during generation. This is often used to hide `__init__`, and other special methods, that are not intended to be part of a library's documentation. For example: ```mojo struct Foo: @doc_private fn __init__(out self): "This should not be called directly, use `Foo.create` instead." return @staticmethod fn create() -> Self: return Self() ```
--- ## documentation
Provides decorators and utilities for interacting with Mojo documentation generation and validation. These are Mojo built-ins, so you don't need to import them. ## Functions * [​`doc_private`](/mojo/stdlib/documentation/documentation/doc_private): Indicate that the decorated declaration is private from the viewpoint of documentation generation.
--- ## documentation (Documentation)
Implements the documentation package. ## Modules * [​`documentation`](/mojo/stdlib/documentation/documentation/): Provides decorators and utilities for interacting with Mojo documentation generation and validation.
--- ## block
Compatibility wrapper for gpu.block module. This module has been moved to gpu.primitives.block. This file provides backward compatibility for existing code that imports from gpu.block. DEPRECATED: Import from gpu.primitives.block instead.
--- ## cluster
GPU cluster operations (deprecated - use `gpu.primitives.cluster` or `gpu`). This module is deprecated. For new code, import cluster operations from the `gpu` package or `gpu.primitives.cluster` module: ```mojo # Deprecated: from gpu.cluster import cluster_sync, cluster_arrive # Recommended (import from top-level gpu package): from gpu import cluster_sync, cluster_arrive # Or import the module: from gpu.primitives import cluster ``` This module provides cluster-level synchronization operations for NVIDIA SM90+ GPUs (Hopper architecture and newer).
--- ## arch (Arch)
Architecture-specific MMA implementations. This package contains GPU architecture-specific implementations of matrix multiply-accumulate (MMA) operations: * **mma\_nvidia**: NVIDIA tensor cores (SM70-SM90) - Volta through Hopper * **mma\_nvidia\_sm100**: NVIDIA Blackwell (SM100) tensor cores - 5th gen tensor cores * **mma\_amd**: AMD Matrix Cores (CDNA2/3/4) - Data center GPUs * **mma\_amd\_rdna**: AMD WMMA (RDNA3/4) - Consumer GPUs ## Module Organization Each architecture module contains: * Private implementation functions (prefixed with `_`) * Architecture-specific intrinsic calls * Data type conversions specific to that architecture ## Usage These modules should **not** be imported directly by user code. Instead, use the unified interface in `gpu.compute.mma` which automatically dispatches to the appropriate architecture-specific implementation at compile time: ```mojo from gpu.compute import mma # Automatically dispatches to the correct architecture result = mma(a, b, c) ``` ## Internal Implementation Details The main `gpu.compute.mma` module imports these implementations: ```mojo from .arch.mma_nvidia import _mma_nvidia from .arch.mma_amd import _mma_amd ``` And dispatches based on compile-time architecture detection: ```mojo @parameter if is_nvidia_gpu(): _mma_nvidia(d, a, b, c) elif is_amd_gpu(): _mma_amd[block_size](d, a, b, c) ``` ## Modules * [​`mma_amd`](/mojo/stdlib/gpu/compute/arch/mma_amd/): AMD CDNA Matrix Cores implementation for matrix multiply-accumulate operations. * [​`mma_amd_rdna`](/mojo/stdlib/gpu/compute/arch/mma_amd_rdna/): AMD RDNA3/4 WMMA implementation for matrix multiply-accumulate operations. * [​`mma_nvidia`](/mojo/stdlib/gpu/compute/arch/mma_nvidia/): NVIDIA Tensor Cores implementation for matrix multiply-accumulate operations. * [​`mma_nvidia_sm100`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/): This module includes utilities for working with the SM100 MMA instructions.
--- ## mma_amd
AMD CDNA Matrix Cores implementation for matrix multiply-accumulate operations. This module provides MMA implementations for AMD CDNA2, CDNA3, and CDNA4 data center GPUs using the MFMA (Matrix Fused Multiply-Add) instructions.
--- ## mma_amd_rdna
AMD RDNA3/4 WMMA implementation for matrix multiply-accumulate operations. This module provides MMA implementations for AMD RDNA3 and RDNA4 consumer GPUs using the WMMA (Wave Matrix Multiply Accumulate) instructions. ## Functions * [​`load_matrix_a_amd_rdna16x16x16`](/mojo/stdlib/gpu/compute/arch/mma_amd_rdna/load_matrix_a_amd_rdna16x16x16): Loads 16×16×16 matrix A tile for RDNA (Wave32) architecture. * [​`load_matrix_b_amd_rdna16x16x16`](/mojo/stdlib/gpu/compute/arch/mma_amd_rdna/load_matrix_b_amd_rdna16x16x16): Loads 16×16×16 matrix B tile for RDNA (Wave32) architecture.
--- ## load_matrix_a_amd_rdna16x16x16
`load_matrix_a_amd_rdna16x16x16(a_ptr: UnsafePointer[Float16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.float16, 16]` Loads 16×16×16 matrix A tile for RDNA (Wave32) architecture. This function is optimized for AMD RDNA GPUs (Radeon RX 7000 series) which use Wave32 execution mode. Each thread loads 16 contiguous FP16 elements using an access pattern appropriate for WMMA instructions. Notes: The concrete return type (SIMD\[16]) avoids type ambiguity and padding overhead. This function is architecture-specific for RDNA - for CDNA, use load\_matrix\_a\_amd\_cdna16x16x16() which returns SIMD\[4]. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 16 FP16 values for this thread. `load_matrix_a_amd_rdna16x16x16(a_ptr: UnsafePointer[BFloat16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.bfloat16, 16]` Loads 16×16×16 matrix A tile for RDNA (Wave32) architecture. This function is optimized for AMD RDNA GPUs (Radeon RX 7000 series) which use Wave32 execution mode. Each thread loads 16 contiguous BF16 elements using an access pattern appropriate for WMMA instructions. Notes: The concrete return type (SIMD\[16]) avoids type ambiguity and padding overhead. This function is architecture-specific for RDNA - for CDNA, use load\_matrix\_a\_amd\_cdna16x16x16() which returns SIMD\[4]. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. 
* ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 16 BF16 values for this thread.
--- ## load_matrix_b_amd_rdna16x16x16
`load_matrix_b_amd_rdna16x16x16(b_ptr: UnsafePointer[Float16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.float16, 16]` Loads 16×16×16 matrix B tile for RDNA (Wave32) architecture. This function is optimized for AMD RDNA GPUs (Radeon RX 7000 series) which use Wave32 execution mode. Each thread loads 16 contiguous FP16 elements using an access pattern appropriate for WMMA instructions. Notes: The concrete return type (SIMD\[16]) avoids type ambiguity and padding overhead. This function is architecture-specific for RDNA - for CDNA, use load\_matrix\_b\_amd\_cdna16x16x16() which returns SIMD\[4]. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 16 FP16 values for this thread. `load_matrix_b_amd_rdna16x16x16(b_ptr: UnsafePointer[BFloat16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.bfloat16, 16]` Loads 16×16×16 matrix B tile for RDNA (Wave32) architecture. This function is optimized for AMD RDNA GPUs (Radeon RX 7000 series) which use Wave32 execution mode. Each thread loads 16 contiguous BF16 elements using an access pattern appropriate for WMMA instructions. Notes: The concrete return type (SIMD\[16]) avoids type ambiguity and padding overhead. This function is architecture-specific for RDNA - for CDNA, use load\_matrix\_b\_amd\_cdna16x16x16() which returns SIMD\[4]. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. 
* ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 16 BF16 values for this thread.
--- ## mma_nvidia
NVIDIA Tensor Cores implementation for matrix multiply-accumulate operations. This module provides MMA implementations for NVIDIA GPUs with Tensor Cores, covering architectures from SM70 (Volta) through SM90 (Hopper). Supported operations: * FP16 accumulation (SM70+) * FP32 accumulation with FP16/BF16 inputs (SM80+) * TF32 operations (SM80+) * FP8 operations (SM89+)
--- ## MMASmemDescriptor
`@register_passable(trivial)` `struct MMASmemDescriptor` Descriptor for shared memory operands of tcgen05 MMA instructions. This struct represents a descriptor that encodes information about shared memory layout and access patterns for warp group matrix multiply operations. The descriptor contains the following bit fields: | Bit field | Size | Description | | --------- | ---- | ------------------------------------------------------------------------------------------------------------------------------------------------------------ | | 0-13 | 14 | Base address in shared memory | | 16-29 | 14 | LBO: leading dim byte offset | | 32-45 | 14 | SBO: stride dim byte offset | | 46-48 | 3 | Fixed constant value: 0b001 | | 49-51 | 3 | Matrix base offset, 0 for canonical layouts | | 52 | 1 | Leading dimension stride mode:  0: byte offset relative  1: byte address absolute (only used for 48B K tile) | | 53-60 | 8 | Fixed constant value: 0 | | 61-63 | 3 | Swizzle mode:  0: No swizzling  1: 128-Byte with 32B atomic swizzling  2: 128-Byte swizzling  4: 64-Byte swizzling  6: 32-Byte swizzling | Note: * Some bits are unused. * Base address, LBO, and SBO ignore the 4 least significant bits. ## Fields * ​desc (`UInt64`): The 64-bit descriptor encodes shared memory operand information. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MMAOperandDescriptor`](/mojo/stdlib/gpu/compute/mma_operand_descriptor/MMAOperandDescriptor), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `mask_14_bits` `comptime mask_14_bits = 16383` Mask with the lower 14 bits set. ## Methods ### `__init__` `__init__(val: UInt64) -> Self` Initialize descriptor with raw 64-bit value. This constructor allows creating a descriptor directly from a 64-bit integer that already contains the properly formatted bit fields for the descriptor. The implicit attribute enables automatic conversion from `UInt64` to `MMASmemDescriptor`. **Args:** * ​val ([`UInt64`](/mojo/stdlib/builtin/simd/#uint64)): A 64-bit integer containing the complete descriptor bit layout. ### `__add__` `__add__(self, offset: Int) -> Self` Add offset to descriptor's base address. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Byte offset to add to base address. **Returns:** `Self`: New descriptor with updated base address. ### `__iadd__` `__iadd__(mut self, offset: Int)` Add offset to descriptor's base address in-place. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Byte offset to add to base address. ### `create` `static create[stride_byte_offset: Int, leading_byte_offset: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](smem_ptr: UnsafePointer[type, origin, address_space=AddressSpace.SHARED]) -> Self` Create a descriptor for shared memory operand. 
**Parameters:** * ​stride\_byte\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Stride dimension offset in bytes. * ​leading\_byte\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension stride in bytes. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern mode. **Args:** * ​smem\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to shared memory operand. **Returns:** `Self`: Initialized descriptor for the shared memory operand.
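The bit layout documented above can be made concrete with a small packing sketch. This is an illustrative Python model of the documented fields, not the Mojo implementation; the helper name and defaults are assumptions:

```python
def pack_smem_descriptor(base_addr, lbo, sbo, base_offset=0, swizzle=0):
    """Pack the tcgen05 shared-memory descriptor fields (illustrative)."""
    mask14 = (1 << 14) - 1
    desc = 0
    # Base address, LBO, and SBO ignore their 4 least significant bits,
    # so the stored 14-bit field is the value shifted right by 4.
    desc |= ((base_addr >> 4) & mask14) << 0    # bits 0-13: base address
    desc |= ((lbo >> 4) & mask14) << 16         # bits 16-29: LBO
    desc |= ((sbo >> 4) & mask14) << 32         # bits 32-45: SBO
    desc |= 0b001 << 46                         # bits 46-48: fixed constant
    desc |= (base_offset & 0b111) << 49         # bits 49-51: matrix base offset
    desc |= (swizzle & 0b111) << 61             # bits 61-63: swizzle mode
    return desc

d = pack_smem_descriptor(0x100, 0x80, 0x200, swizzle=2)
```

Each field occupies a disjoint bit range, so unpacking is a shift and a mask, e.g. `(d >> 16) & 0x3FFF` recovers the stored LBO.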
--- ## MMASmemDescriptorPair
`@register_passable(trivial)` `struct MMASmemDescriptorPair` Descriptor for shared memory operands of tcgen05 MMA instructions, stored as two 32-bit words. This struct represents a descriptor that encodes information about shared memory layout and access patterns for warp group matrix multiply operations. The descriptor contains the following bit fields: | Bit field | Size | Description | | --------- | ---- | ----------- | | 0-13 | 14 | Base address in shared memory | | 16-29 | 14 | LBO: leading dim byte offset | | 32-45 | 14 | SBO: stride dim byte offset | | 46-48 | 3 | Fixed constant value: 0b001 | | 49-51 | 3 | Matrix base offset, 0 for canonical layouts | | 52 | 1 | Leading dimension stride mode:  0: byte offset relative  1: byte address absolute (only used for 48B K tile) | | 53-60 | 8 | Fixed constant value: 0 | | 61-63 | 3 | Swizzle mode:  0: No swizzling  1: 128-Byte with 32B atomic swizzling  2: 128-Byte swizzling  4: 64-Byte swizzling  6: 32-Byte swizzling | Note: * Some bits are unused. * Base address, LBO, and SBO ignore the 4 least significant bits. ## Fields * ​hi (`UInt32`): The high 32 bits of the descriptor. * ​lo (`UInt32`): The low 32 bits of the descriptor. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `mask_14_bits` `comptime mask_14_bits = 16383` Mask with the lower 14 bits set. 
## Methods ### `__init__` `__init__(hi: UInt32, lo: UInt32) -> Self` Initialize the descriptor from its two 32-bit halves. This constructor creates a descriptor directly from the high and low 32-bit words, which together contain the properly formatted bit fields for the descriptor. **Args:** * ​hi ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): A 32-bit integer containing the upper half of the descriptor layout. * ​lo ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): A 32-bit integer containing the lower half of the descriptor layout. ### `__add__` `__add__(self, offset: UInt32) -> Self` Add offset to descriptor's base address. **Args:** * ​offset ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Byte offset to add to base address. **Returns:** `Self`: New descriptor with updated base address. ### `__iadd__` `__iadd__(mut self, offset: UInt32)` Add offset to descriptor's base address in-place. **Args:** * ​offset ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Byte offset to add to base address. ### `create` `static create[stride_byte_offset: Int, leading_byte_offset: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](smem_ptr: UnsafePointer[type, origin, address_space=AddressSpace.SHARED]) -> Self` Create a descriptor for shared memory operand. **Parameters:** * ​stride\_byte\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Stride dimension offset in bytes. * ​leading\_byte\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension stride in bytes. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern mode. **Args:** * ​smem\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to shared memory operand. **Returns:** `Self`: Initialized descriptor for the shared memory operand.
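Since the pair stores the same 64-bit layout as `MMASmemDescriptor` split across two 32-bit words, the split can be sketched in Python. This assumes `hi` holds the upper word and `lo` the lower; the function name is illustrative:

```python
def split_descriptor(desc64):
    """Split a 64-bit smem descriptor into (hi, lo) 32-bit words."""
    lo = desc64 & 0xFFFFFFFF          # low word: base address and LBO fields
    hi = (desc64 >> 32) & 0xFFFFFFFF  # high word: SBO, base offset, swizzle
    return hi, lo

# The fixed constant at bits 46-48 lands at bit 14 of the high word.
hi, lo = split_descriptor(0b001 << 46)
```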
--- ## UMMAInsDescriptor
`@register_passable(trivial)` `struct UMMAInsDescriptor[mma_kind: UMMAKind]` Descriptor for UMMA instructions. This struct represents a descriptor that encodes information about UMMA instructions. The descriptor contains the following bit fields: * Sparsity (2 bits): The sparsity of the input matrices. Currently defaults to dense matrices. * Saturate for integer types (1 bit): Whether to saturate the result for integer types. Currently not supported. * Matrix D type (2 bits): Data type of matrix D. * Matrix A type (3 bits): Data type of matrix A. * Matrix B type (3 bits): Data type of matrix B. * Negate A matrix (1 bit): Whether to negate matrix A. Currently defaults to False. * Negate B matrix (1 bit): Whether to negate matrix B. Currently defaults to False. * Transpose A (1 bit): Whether to transpose matrix A. * Transpose B (1 bit): Whether to transpose matrix B. * N, Dimension of Matrix B (6 bits): Number of columns in matrix B. 3 LSBs are unused. * M, Dimension of Matrix A (6 bits): Number of rows in matrix A. 3 LSBs are unused. ## Parameters * ​mma\_kind ([`UMMAKind`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAKind)): The kind of UMMA instruction. ## Fields * ​desc (`UInt32`): The 32-bit descriptor value that encodes UMMA instruction information. 
This field stores the complete descriptor with all bit fields packed into a single 32-bit integer: * Bits 0-1: Sparsity selector (2 bits) * Bit 2: Sparsity enable (1 bit) * Bit 3: Saturate for integer types (1 bit) * Bits 4-5: Matrix D type (2 bits) * Bit 6: Reserved (1 bit) * Bits 7-9: Matrix A type (3 bits) * Bits 10-12: Matrix B type (3 bits) * Bit 13: Negate A matrix (1 bit) * Bit 14: Negate B matrix (1 bit) * Bit 15: Transpose A (1 bit) * Bit 16: Transpose B (1 bit) * Bits 17-22: N, Dimension of Matrix B (6 bits) * Bit 23: Reserved (1 bit) * Bits 24-28: M, Dimension of Matrix A (5 bits) * Bit 29: Reserved (1 bit) * Bits 30-31: Maximum shift while attempting B matrix (2 bits) ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(value: UInt32) -> Self` Initialize descriptor with raw 32-bit value. This constructor allows creating a descriptor directly from a 32-bit integer that already contains the properly formatted bit fields for the descriptor. **Args:** * ​value ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): A 32-bit integer containing the complete descriptor bit layout. ### `create` `static create[d_type: DType, a_type: DType, b_type: DType, output_shape: IndexList[2, element_type=DType.uint32], /, *, transpose_a: Bool = False, transpose_b: Bool = True]() -> Self` Create a descriptor for UMMA instructions. This function creates a descriptor for UMMA instructions based on the provided parameters. 
**Parameters:** * ​d\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of matrix D. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of matrix A. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of matrix B. * ​output\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the output matrix. * ​transpose\_a ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose matrix A. * ​transpose\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose matrix B. **Returns:** `Self`: A 32-bit integer containing the complete descriptor bit layout. `static create[d_type: DType, a_type: DType, b_type: DType, scale_type: DType, output_shape: IndexList[2, element_type=DType.uint32], /, *, transpose_a: Bool = False, transpose_b: Bool = True]() -> Self` Create a descriptor for UMMA MXF8F6F4 instructions. This function creates a descriptor for UMMA MXF8F6F4 instructions based on the provided parameters. **Parameters:** * ​d\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of matrix D. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of matrix A. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of matrix B. * ​scale\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the scale factors (only applicable to MXF8F6F4). * ​output\_shape ([`IndexList`](/mojo/stdlib/utils/index_/IndexList)): The shape of the output matrix. * ​transpose\_a ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose matrix A. * ​transpose\_b ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose matrix B. **Returns:** `Self`: A 32-bit integer containing the complete descriptor bit layout. ### `update_desc_with_sf_id` `static update_desc_with_sf_id[sf_id: UInt32](inst_desc) -> Self` Update the descriptor with the scale factor ID. 
**Parameters:** * ​sf\_id ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The scale factor ID. **Args:** * ​inst\_desc (`Self`): The descriptor to update. **Returns:** `Self`: The updated descriptor.
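As a worked illustration of the instruction-descriptor layout, the packing can be sketched in Python. The dtype codes passed in are hypothetical placeholders, not the real field encodings; only the bit positions follow the layout documented above:

```python
def pack_umma_inst_desc(d_code, a_code, b_code, n, m,
                        transpose_a=False, transpose_b=True):
    """Pack a 32-bit UMMA instruction descriptor (illustrative)."""
    desc = 0
    desc |= (d_code & 0b11) << 4        # bits 4-5: matrix D type
    desc |= (a_code & 0b111) << 7       # bits 7-9: matrix A type
    desc |= (b_code & 0b111) << 10      # bits 10-12: matrix B type
    desc |= int(transpose_a) << 15      # bit 15: transpose A
    desc |= int(transpose_b) << 16      # bit 16: transpose B
    desc |= ((n >> 3) & 0x3F) << 17     # bits 17-22: N, 3 LSBs unused
    desc |= ((m >> 3) & 0x1F) << 24     # bits 24-28: M, 3 LSBs unused
    return desc
```

Because the dimension fields drop their 3 low bits, N and M must be multiples of 8 to round-trip exactly.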
--- ## UMMAKind
`@register_passable(trivial)` `struct UMMAKind` Struct for UMMA instruction types. This struct defines the different types of UMMA instructions that are supported by Blackwell. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `KIND_F16` `comptime KIND_F16 = UMMAKind(2)` F16 type. ### `KIND_F8F6F4` `comptime KIND_F8F6F4 = UMMAKind(3)` F8F6F4 type. ### `KIND_I8` `comptime KIND_I8 = UMMAKind(4)` I8 type. ### `KIND_MXF8F6F4` `comptime KIND_MXF8F6F4 = UMMAKind(5)` MXF8F6F4 type. ### `KIND_TF32` `comptime KIND_TF32 = UMMAKind(0)` TF32 type. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` Check if two UMMA kinds are equal. **Args:** * ​other (`Self`): The other UMMA kind to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the UMMA kinds are equal, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Check if two UMMA kinds are not equal. **Args:** * ​other (`Self`): The other UMMA kind to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the UMMA kinds are not equal, False otherwise. ### `__int__` `__int__(self) -> Int` Convert UMMA kind to an integer value. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value representing the UMMA instruction type. ### `__str__` `__str__(self) -> String` Convert the UMMA kind to a string; this can be used as the instruction qualifier. 
**Returns:** [`String`](/mojo/stdlib/collections/string/string/String): The PTX qualifier representation of the UMMA kind. ### `write_to` `write_to(self, mut writer: T)` Write the UMMA kind to a writer. **Args:** * ​writer (`T`): The writer to write the UMMA kind to.
--- ## mma_nvidia_sm100
This module includes utilities for working with the SM100 MMA instructions. ## Structs * [​`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor): Descriptor for shared memory operands tcgen05 mma instructions. * [​`MMASmemDescriptorPair`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptorPair): Descriptor for shared memory operands tcgen05 mma instructions. * [​`UMMAInsDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAInsDescriptor): Descriptor for UMMA instructions. * [​`UMMAKind`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAKind): Struct for UMMA instruction types. ## Functions * [​`mma`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/mma): Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction. * [​`mma_arrive`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/mma_arrive): Arrive at the mbar pointer for the MMA instruction. * [​`mma_arrive_multicast`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/mma_arrive_multicast): Arrive at the mbar pointer for the MMA instruction for multiple ctas.
--- ## mma (Mma_nvidia_sm100)
`mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: UInt32 = 1](a_desc: MMASmemDescriptor, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind])` Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction. **Parameters:** * ​kind ([`UMMAKind`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAKind)): Data type of the matrices. * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of ctas used by MMA. * ​c\_scale ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Scale factor for the C matrix, 0 or 1. **Args:** * ​a\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the A matrix. * ​b\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the B matrix. * ​c\_tmem ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the C matrix in the tensor memory. * ​inst\_desc ([`UMMAInsDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAInsDescriptor)): The descriptor for the MMA instruction. `mma[kind: UMMAKind, //, cta_group: Int = 1, /](a_desc: MMASmemDescriptor, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind], sfa_tmem: UInt32, sfb_tmem: UInt32, c_scale: UInt32)` Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction. **Parameters:** * ​kind ([`UMMAKind`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAKind)): Data type of the matrices. * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of ctas used by MMA. **Args:** * ​a\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the A matrix. * ​b\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the B matrix. * ​c\_tmem ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the C matrix in the tensor memory. 
* ​inst\_desc ([`UMMAInsDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAInsDescriptor)): The descriptor for the MMA instruction. * ​sfa\_tmem ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the block scale factor A in the tensor memory. * ​sfb\_tmem ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the block scale factor B in the tensor memory. * ​c\_scale ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Scale factor for the C matrix, 0 or 1. `mma[kind: UMMAKind, //, cta_group: Int = 1, /](a_desc: MMASmemDescriptor, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind], c_scale: UInt32)` Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction. **Parameters:** * ​kind ([`UMMAKind`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAKind)): Data type of the matrices. * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of ctas used by MMA. **Args:** * ​a\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the A matrix. * ​b\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the B matrix. * ​c\_tmem ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the C matrix in the tensor memory. * ​inst\_desc ([`UMMAInsDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAInsDescriptor)): The descriptor for the MMA instruction. * ​c\_scale ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Scale factor for the C matrix. Any non-zero value is translated to `1`. `mma[kind: UMMAKind, //, cta_group: Int = 1, /](a_desc: UInt32, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind], c_scale: UInt32)` Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction. **Parameters:** * ​kind ([`UMMAKind`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAKind)): Data type of the matrices. 
* ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of ctas used by MMA. **Args:** * ​a\_desc ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The descriptor for the A matrix. * ​b\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the B matrix. * ​c\_tmem ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the C matrix in the tensor memory. * ​inst\_desc ([`UMMAInsDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAInsDescriptor)): The descriptor for the MMA instruction. * ​c\_scale ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Scale factor for the C matrix. Any non-zero value is interpreted as `1`. `mma[kind: UMMAKind, //, cta_group: Int = 1, /, *, c_scale: UInt32 = 1](a_desc: UInt32, b_desc: MMASmemDescriptor, c_tmem: UInt32, inst_desc: UMMAInsDescriptor[kind])` Perform a matrix multiply-accumulate operation using the tcgen05.mma instruction. **Parameters:** * ​kind ([`UMMAKind`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAKind)): Data type of the matrices. * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of ctas used by MMA. * ​c\_scale ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Scale factor for the C matrix, 0 or 1. **Args:** * ​a\_desc ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The descriptor for the A matrix. * ​b\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): The descriptor for the B matrix. * ​c\_tmem ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the C matrix in the tensor memory. * ​inst\_desc ([`UMMAInsDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/UMMAInsDescriptor)): The descriptor for the MMA instruction.
--- ## mma_arrive
`mma_arrive[cta_group: Int = 1](mbar_ptr: UnsafePointer[type, origin, address_space=AddressSpace.SHARED])` Arrive at the mbar pointer for the MMA instruction. **Parameters:** * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of ctas used by MMA. **Args:** * ​mbar\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to the mbar.
--- ## mma_arrive_multicast
`mma_arrive_multicast[cta_group: Int = 1](mbar_ptr: UnsafePointer[type, origin, address_space=AddressSpace.SHARED], cta_mask: UInt16)` Arrive at the mbar pointer for the MMA instruction for multiple ctas. **Parameters:** * ​cta\_group ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of ctas used by MMA. **Args:** * ​mbar\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to the mbar. * ​cta\_mask ([`UInt16`](/mojo/stdlib/builtin/simd/#uint16)): Mask of ctas to signal.
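The `cta_mask` argument selects the participating CTAs with one bit per CTA. A small Python sketch of building such a mask (the helper name and CTA-ID inputs are illustrative):

```python
def make_cta_mask(cta_ids):
    """Build a 16-bit mask selecting which CTAs to signal."""
    mask = 0
    for i in cta_ids:
        assert 0 <= i < 16, "a UInt16 mask addresses at most 16 CTAs"
        mask |= 1 << i  # set the bit for CTA i
    return mask
```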
--- ## compute
GPU compute operations package - MMA and tensor core operations. This package provides GPU tensor core and matrix multiplication operations: * **mma**: Unified warp matrix-multiply-accumulate (WMMA) operations * **mma\_util**: Utility functions for loading/storing MMA operands * **mma\_operand\_descriptor**: Operand descriptor types for MMA * **tensor\_ops**: Tensor core-based reductions and operations * **tcgen05**: 5th generation tensor core operations (Blackwell) * **arch/**: Architecture-specific MMA implementations (internal) * `mma_nvidia`: NVIDIA tensor cores (SM70-SM90) * `mma_nvidia_sm100`: NVIDIA Blackwell (SM100) * `mma_amd`: AMD Matrix Cores (CDNA2/3/4) * `mma_amd_rdna`: AMD WMMA (RDNA3/4) ## Usage Import compute operations directly: ```mojo from gpu.compute import mma # Automatically dispatches to the correct GPU architecture; accumulates a * b + c into d mma.mma(d, a, b, c) ``` Architecture-specific implementations in `arch/` are internal and should not be imported directly by user code. ## Packages * [​`arch`](/mojo/stdlib/gpu/compute/arch/): Architecture-specific MMA implementations. ## Modules * [​`mma`](/mojo/stdlib/gpu/compute/mma/): This module includes utilities for working with the warp-matrix-matrix-multiplication (wmma) instructions. * [​`mma_operand_descriptor`](/mojo/stdlib/gpu/compute/mma_operand_descriptor/): * [​`mma_util`](/mojo/stdlib/gpu/compute/mma_util/): Matrix multiply accumulate (MMA) utilities for GPU tensor cores. * [​`tcgen05`](/mojo/stdlib/gpu/compute/tcgen05/): This module includes utilities for working with the tensorcore 5th generation (tcgen05) instructions. * [​`tensor_ops`](/mojo/stdlib/gpu/compute/tensor_ops/): This module provides tensor core operations and utilities for GPU computation.
--- ## WGMMADescriptor
`@register_passable(trivial)` `struct WGMMADescriptor[dtype: DType]` Descriptor for shared memory operands used in warp group matrix multiply operations. This struct represents a descriptor that encodes information about shared memory layout and access patterns for warp group matrix multiply operations. The descriptor contains the following bit fields: | Bit field | Size | Description | | --------- | ---- | ----------- | | 0-13 | 14 | Base address in shared memory | | 16-29 | 14 | LBO: leading dim byte offset | | 32-45 | 14 | SBO: stride dim byte offset | | 49-51 | 3 | Matrix base offset, 0 for canonical layouts | | 62-63 | 2 | Swizzle mode:   0: no swizzle,   1: 128B swizzle,   2: 64B swizzle,   3: 32B swizzle | Note: * Some bits are unused. * Base address, LBO, and SBO ignore the 4 least significant bits. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the shared memory operand. This affects memory alignment and access patterns for the descriptor. ## Fields * ​desc (`Int64`): The 64-bit descriptor value that encodes shared memory layout information. This field stores the complete descriptor with all bit fields packed into a single 64-bit integer: * Bits 0-13: Base address in shared memory (14 bits) * Bits 16-29: Leading dimension stride in bytes (14 bits) * Bits 32-45: Stride dimension offset in bytes (14 bits) * Bits 49-51: Base offset (3 bits) * Bits 62-63: Swizzle mode for memory access pattern (2 bits) The descriptor is used by NVIDIA Hopper architecture's warp group matrix multiply instructions to efficiently access shared memory with the appropriate layout and access patterns. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`MMAOperandDescriptor`](/mojo/stdlib/gpu/compute/mma_operand_descriptor/MMAOperandDescriptor), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(val: Int64) -> Self` Initialize descriptor with raw 64-bit value. This constructor allows creating a descriptor directly from a 64-bit integer that already contains the properly formatted bit fields for the descriptor. The implicit attribute enables automatic conversion from `Int64` to `WGMMADescriptor`. **Args:** * ​val ([`Int64`](/mojo/stdlib/builtin/simd/#int64)): A 64-bit integer containing the complete descriptor bit layout. ### `__add__` `__add__(self, offset: Int) -> Self` Add offset to descriptor's base address. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Byte offset to add to base address. **Returns:** `Self`: New descriptor with updated base address. ### `__iadd__` `__iadd__(mut self, offset: Int)` Add offset to descriptor's base address in-place. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Byte offset to add to base address. ### `create` `static create[stride_byte_offset: Int, leading_byte_offset: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](smem_ptr: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED]) -> Self` Create a descriptor for shared memory operand. **Parameters:** * ​stride\_byte\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Stride dimension offset in bytes. 
* ​leading\_byte\_offset ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension stride in bytes. * ​swizzle\_mode ([`TensorMapSwizzle`](/mojo/stdlib/gpu/host/nvidia/tma/TensorMapSwizzle)): Memory access pattern mode. **Args:** * ​smem\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to shared memory operand. **Returns:** `Self`: Initialized descriptor for the shared memory operand.
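The WGMMA bit layout and the byte-offset arithmetic of `__add__` can be sketched in Python. The helper names are illustrative, and the offset arithmetic is one plausible reading of `__add__` given that the base-address field drops the 4 least significant bits:

```python
MASK14 = (1 << 14) - 1

def pack_wgmma_descriptor(base_addr, lbo, sbo, base_offset=0, swizzle=0):
    """Pack the WGMMA shared-memory descriptor fields (illustrative)."""
    desc = 0
    # Base address, LBO, and SBO ignore their 4 least significant bits.
    desc |= ((base_addr >> 4) & MASK14) << 0   # bits 0-13: base address
    desc |= ((lbo >> 4) & MASK14) << 16        # bits 16-29: LBO
    desc |= ((sbo >> 4) & MASK14) << 32        # bits 32-45: SBO
    desc |= (base_offset & 0b111) << 49        # bits 49-51: base offset
    desc |= (swizzle & 0b11) << 62             # bits 62-63: swizzle mode
    return desc

def add_byte_offset(desc, offset):
    """Advance the base-address field by a byte offset.

    Assumes the new base address still fits in the 14-bit field,
    so the addition does not carry into the LBO field.
    """
    return desc + (offset >> 4)
```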
--- ## get_amd_bf8_dtype
`get_amd_bf8_dtype() -> DType` Gets the appropriate BF8 dtype for the current AMD GPU architecture. **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType): `DType.float8_e5m2` for CDNA4+ and RDNA4+ GPUs, `DType.float8_e5m2fnuz` for CDNA1-3 GPUs, or `DType.invalid` for RDNA3 (no native BF8 support).
--- ## get_amd_fp8_dtype
`get_amd_fp8_dtype() -> DType` Gets the appropriate FP8 dtype for the current AMD GPU architecture. **Returns:** [`DType`](/mojo/stdlib/builtin/dtype/DType): `DType.float8_e4m3fn` for CDNA4+ and RDNA4+ GPUs, `DType.float8_e4m3fnuz` for CDNA1-3 GPUs, or `DType.invalid` for RDNA3 (no native FP8 support).
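Both helpers follow the same selection pattern over the GPU generation. A Python sketch of that logic, where the architecture tags are hypothetical stand-ins for the real hardware query:

```python
def select_amd_fp8_dtype(arch):
    """Mirror the documented FP8 dtype choice per AMD architecture (illustrative)."""
    if arch in ("cdna4", "rdna4"):
        return "float8_e4m3fn"    # newer GPUs use the standard e4m3fn format
    if arch in ("cdna1", "cdna2", "cdna3"):
        return "float8_e4m3fnuz"  # earlier CDNA GPUs use the fnuz variant
    return "invalid"              # e.g. rdna3: no native FP8 support
```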
--- ## mma (4)
This module includes utilities for working with the warp-matrix-matrix-multiplication (wmma) instructions. ## Structs * [​`WGMMADescriptor`](/mojo/stdlib/gpu/compute/mma/WGMMADescriptor): Descriptor for shared memory operands used in warp group matrix multiply operations. ## Functions * [​`get_amd_bf8_dtype`](/mojo/stdlib/gpu/compute/mma/get_amd_bf8_dtype): Gets the appropriate BF8 dtype for the current AMD GPU architecture. * [​`get_amd_fp8_dtype`](/mojo/stdlib/gpu/compute/mma/get_amd_fp8_dtype): Gets the appropriate FP8 dtype for the current AMD GPU architecture. * [​`ld_matrix`](/mojo/stdlib/gpu/compute/mma/ld_matrix): Loads a matrix from shared memory into registers in a format suitable for tensor core operations. * [​`mma`](/mojo/stdlib/gpu/compute/mma/mma): Performs warp sync Tensor Core based Matrix-multiply and accumulate (MMA) operation. * [​`st_matrix`](/mojo/stdlib/gpu/compute/mma/st_matrix): Performs warp-synchronized copy from registers to shared memory. * [​`wgmma_async`](/mojo/stdlib/gpu/compute/mma/wgmma_async): Performs warp group async Matrix-multiply and accumulate (WGMMA) operation. * [​`wgmma_commit_group_sync`](/mojo/stdlib/gpu/compute/mma/wgmma_commit_group_sync): Commits pending warp group matrix multiply operations. * [​`wgmma_fence_aligned`](/mojo/stdlib/gpu/compute/mma/wgmma_fence_aligned): Inserts a memory fence for warp group matrix multiply operations. * [​`wgmma_wait_group_sync`](/mojo/stdlib/gpu/compute/mma/wgmma_wait_group_sync): Waits for all pending warp group matrix multiply operations to complete.
--- ## ld_matrix
`ld_matrix[dtype: DType, //, simd_width: Int, *, transpose: Bool = False](ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space]) -> SIMD[dtype, simd_width]` Loads a matrix from shared memory into registers in a format suitable for tensor core operations. This function performs a warp-synchronized load from shared memory to registers, formatting the data to be directly usable by tensor core Matrix Multiply-Accumulate (MMA) instructions. Note: * All threads in a warp must execute this operation together. * For transposed loads, only half precision (float16) is supported. * The register width is fixed at 4 bytes (32 bits). * Supported configurations: * x1: One 32-bit register per thread. * x2: Two 32-bit registers per thread. * x4: Four 32-bit registers per thread. Example: ```mojo from gpu.compute.mma import ld_matrix # Load 8x8 matrix of float16 values var data = ld_matrix[DType.float16, 8](ptr) # Load transposed matrix var transposed = ld_matrix[DType.float16, 8, transpose=True](ptr) ``` **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type of the matrix elements (e.g. float16, float32). * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the SIMD vector to load. * ​transpose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to transpose the matrix during load (only supported for half precision). **Args:** * ​ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to shared memory containing the source matrix data. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing the loaded matrix data, properly formatted for MMA operations.
--- ## mma (5)
`mma[block_size: Int = 1](mut d: SIMD[dtype, size], a: SIMD[dtype, size], b: SIMD[dtype, size], c: SIMD[dtype, size])` Performs warp sync Tensor Core based Matrix-multiply and accumulate (MMA) operation. This function executes a matrix multiply-accumulate operation using GPU Tensor Cores, synchronizing across the warp. It dispatches to architecture-specific implementations for NVIDIA and AMD GPUs. The operation performed is: d = (a \* b) + c Supported configurations depend on the GPU architecture: * NVIDIA: Various combinations of FP32, FP16, BF16, and FP8 formats * AMD: Limited subset of FP32 and FP16 operations Note: * All threads in a warp must execute this operation together * Input matrices must be properly loaded and formatted for Tensor Core operations * Matrix dimensions and data types must match hardware requirements **Parameters:** * ​block\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of the block of the MMA operation (e.g., 4x4x4\_16B). Applies to AMD GPUs only. **Args:** * ​d ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Output SIMD vector to store the result. * ​a ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): First input matrix as SIMD vector. * ​b ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Second input matrix as SIMD vector. * ​c ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Accumulator matrix as SIMD vector.
--- ## st_matrix
`st_matrix[dtype: DType, //, simd_width: Int, *, transpose: Bool = False](ptr: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED], d: SIMD[DType.float32, simd_width])` Performs warp-synchronized copy from registers to shared memory. This function stores data from registers to shared memory in a format that can be directly used by tensor core Matrix Multiply-Accumulate (MMA) instructions. It uses the NVIDIA stmatrix instruction to perform an efficient warp-synchronized store. Note: The function performs a warp-synchronized operation - all threads in the warp must execute this instruction to avoid deadlock. **Constraints:** * Must be used with shared memory pointers. * Number of registers must be 1, 2, or 4. * Data must be properly aligned for matrix operations. * All threads in warp must participate. * Only supported on NVIDIA GPUs with tensor core capabilities. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of elements to store. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): Width of the SIMD vector. * ​transpose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): If True, transposes the matrix during store. **Args:** * ​ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to shared memory where data will be stored. * ​d ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): SIMD vector containing the data to store.
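Example (a hedged sketch pairing `ld_matrix` with `st_matrix`; the shared-memory pointer type annotation and its alignment setup are assumptions, not a verified kernel):

```mojo
from gpu.compute.mma import ld_matrix, st_matrix

fn roundtrip(
    smem: UnsafePointer[Float32, address_space = AddressSpace.SHARED],
):
    # Load an MMA-formatted fragment from shared memory, then store it
    # back; all threads in the warp must execute both instructions.
    var frag = ld_matrix[DType.float32, 4](smem)
    st_matrix[DType.float32, 4](smem, frag)
```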
--- ## wgmma_async
`wgmma_async[m: Int, n: Int, k: Int, c_dtype: DType, width: Int, /, *, a_type: DType, b_type: DType, accum_type: DType = c_dtype, layout_a: StringSlice[StaticConstantOrigin] = "row", layout_b: StringSlice[StaticConstantOrigin] = "col", scale_d: Int = 1, scale_a: Int = 1, scale_b: Int = 1](mat_a_desc: WGMMADescriptor[dtype], mat_b_desc: WGMMADescriptor[dtype], c_reg: StaticTuple[Scalar[c_dtype], width]) -> StaticTuple[Scalar[c_dtype], width]` Performs warp group async Matrix-multiply and accumulate (WGMMA) operation. This function executes an asynchronous matrix multiplication using warp group MMA instructions. It supports various data types including tensor float32, bfloat16, float16, float8, int8, and uint8. **Constraints:** * The number of output registers must match the instruction shape: `(m * n // 128) * size_of(accum_type) == width * size_of(c_dtype)`. * Data type combinations must be compatible with hardware WGMMA instructions. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in matrix A and output matrix. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in matrix B and output matrix. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in matrix A / rows in matrix B. * ​c\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the output matrix C. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Width of the InlineArray register for matrix C. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of matrix A. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of matrix B. * ​accum\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Accumulation data type (defaults to c\_dtype). * ​layout\_a ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Memory layout for matrix A ("row" or "col"). * ​layout\_b ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Memory layout for matrix B ("row" or "col"). 
* ​scale\_d ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix C. * ​scale\_a ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix A. * ​scale\_b ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix B. **Args:** * ​mat\_a\_desc ([`WGMMADescriptor`](/mojo/stdlib/gpu/compute/mma/WGMMADescriptor)): WGMMA descriptor for matrix A. * ​mat\_b\_desc ([`WGMMADescriptor`](/mojo/stdlib/gpu/compute/mma/WGMMADescriptor)): WGMMA descriptor for matrix B. * ​c\_reg ([`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple)): StaticTuple containing matrix C values. **Returns:** [`StaticTuple`](/mojo/stdlib/utils/static_tuple/StaticTuple): `StaticTuple` containing the result of the matrix multiplication. `wgmma_async[m: Int, n: Int, k: Int, c_dtype: DType, width: Int, /, *, a_type: DType, b_type: DType, accum_type: DType = c_dtype, layout_a: StringSlice[StaticConstantOrigin] = "row", layout_b: StringSlice[StaticConstantOrigin] = "col", scale_d: Int = 1, scale_a: Int = 1, scale_b: Int = 1](mat_a_desc: WGMMADescriptor[dtype], mat_b_desc: WGMMADescriptor[dtype], c_reg: SIMD[c_dtype, width]) -> SIMD[c_dtype, width]` Performs warp group async Matrix-multiply and accumulate (WGMMA) operation. This function executes an asynchronous matrix multiplication using warp group MMA instructions. It supports various data types including tensor float32, bfloat16, float16, float8, int8, and uint8. **Constraints:** * The number of output registers must match the instruction shape: `(m * n // 128) * size_of(accum_type) == width * size_of(c_dtype)`. * Data type combinations must be compatible with hardware WGMMA instructions. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in matrix A and output matrix. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in matrix B and output matrix. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in matrix A / rows in matrix B. 
* ​c\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the output matrix C. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): Width of the SIMD register for matrix C. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of matrix A. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of matrix B. * ​accum\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Accumulation data type (defaults to c\_dtype). * ​layout\_a ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Memory layout for matrix A ("row" or "col"). * ​layout\_b ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Memory layout for matrix B ("row" or "col"). * ​scale\_d ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix C. * ​scale\_a ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix A. * ​scale\_b ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix B. **Args:** * ​mat\_a\_desc ([`WGMMADescriptor`](/mojo/stdlib/gpu/compute/mma/WGMMADescriptor)): WGMMA descriptor for matrix A. * ​mat\_b\_desc ([`WGMMADescriptor`](/mojo/stdlib/gpu/compute/mma/WGMMADescriptor)): WGMMA descriptor for matrix B. * ​c\_reg ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): SIMD register containing matrix C values. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD register containing the result of the matrix multiplication. `wgmma_async[m: Int, n: Int, k: Int, a_dtype: DType, c_dtype: DType, frag_a_width: Int, frag_c_width: Int, /, *, a_type: DType, b_type: DType, accum_type: DType = c_dtype, layout_a: StringSlice[StaticConstantOrigin] = "row", layout_b: StringSlice[StaticConstantOrigin] = "col", scale_d: Int = 1, scale_a: Int = 1, scale_b: Int = 1](mat_a_frag: SIMD[a_dtype, frag_a_width], mat_b_desc: WGMMADescriptor[dtype], c: SIMD[c_dtype, frag_c_width]) -> SIMD[c_dtype, frag_c_width]` Performs warp group async Matrix-multiply and accumulate (WGMMA) operation. 
Currently only supports: * m=64, k=16. * BF16 input types. * FP32 accumulation. * Row major matrix A. * Column major matrix B (or row major for BF16). **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in output matrix. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in output matrix. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. * ​a\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of matrix A fragment. * ​c\_dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of output matrix C. * ​frag\_a\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): Width of matrix A fragment. * ​frag\_c\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): Width of output matrix C fragment. * ​a\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of matrix A. * ​b\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of matrix B. * ​accum\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type used for accumulation (defaults to c\_dtype). * ​layout\_a ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Layout of matrix A ("row" or "col", defaults to "row"). * ​layout\_b ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Layout of matrix B ("row" or "col", defaults to "col"). * ​scale\_d ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for output matrix C (defaults to 1). * ​scale\_a ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix A (defaults to 1). * ​scale\_b ([`Int`](/mojo/stdlib/builtin/int/Int)): Scale factor for matrix B (defaults to 1). **Args:** * ​mat\_a\_frag ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Fragment containing matrix A data. * ​mat\_b\_desc ([`WGMMADescriptor`](/mojo/stdlib/gpu/compute/mma/WGMMADescriptor)): Descriptor for matrix B data. * ​c ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Fragment containing matrix C data. 
**Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): Updated matrix C fragment after WGMMA operation.
--- ## wgmma_commit_group_sync
`wgmma_commit_group_sync()` Commits pending warp group matrix multiply operations. This synchronizes the warp group and ensures all WGMMA operations have been committed. Must be called after a sequence of WGMMA operations before accessing results.
--- ## wgmma_fence_aligned
`wgmma_fence_aligned()` Inserts a memory fence for warp group matrix multiply operations. This ensures all prior shared memory accesses are visible before subsequent WGMMA operations. Must be called before starting a new sequence of WGMMA operations.
--- ## wgmma_wait_group_sync
`wgmma_wait_group_sync[group: Int = 0]()` Waits for pending warp group matrix multiply operations to complete. This synchronizes the warp group and ensures that at most `group` WGMMA groups remain outstanding. Must be called after commit and before accessing results. **Parameters:** * ​group ([`Int`](/mojo/stdlib/builtin/int/Int)): The maximum number of WGMMA groups that may remain pending after the wait; the default of 0 waits for all outstanding operations.
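The fence/commit/wait functions above bracket `wgmma_async` in a fixed order. A hedged sketch of one tile step (descriptor construction is omitted and the descriptor types are assumptions; the width of 64 is chosen to satisfy the constraint `(m * n // 128) * size_of(accum_type) == width * size_of(c_dtype)` for m=64, n=128 with FP32 accumulation):

```mojo
from gpu.compute.mma import (
    wgmma_async,
    wgmma_commit_group_sync,
    wgmma_fence_aligned,
    wgmma_wait_group_sync,
)

fn wgmma_tile(
    a_desc: WGMMADescriptor[DType.bfloat16],
    b_desc: WGMMADescriptor[DType.bfloat16],
):
    var c = SIMD[DType.float32, 64](0)

    wgmma_fence_aligned()  # make prior shared-memory writes visible
    c = wgmma_async[
        64, 128, 16, DType.float32, 64,
        a_type = DType.bfloat16, b_type = DType.bfloat16,
    ](a_desc, b_desc, c)
    wgmma_commit_group_sync()  # commit the async group
    wgmma_wait_group_sync()    # wait until no groups remain pending
```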
--- ## MMAOperandDescriptor
Trait for abstracting MMA (Matrix Multiply-Accumulate) operand descriptors. This trait defines the interface for WGMMA operand descriptors used in GPU matrix operations. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__copyinit__` is trivial. The implementation of `__copyinit__` is considered to be trivial if: * The struct has a compiler-generated trivial `__copyinit__` and all its fields have a trivial `__copyinit__` method. In practice, it means the value can be copied by copying the bits from one location to another without side effects. ### `__del__is_trivial` `comptime __del__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__del__` is trivial. The implementation of `__del__` is considered to be trivial if: * The struct has a compiler-generated trivial destructor and all its fields have a trivial `__del__` method. In practice, it means that the `__del__` can be considered as no-op. ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial` A flag (often compiler generated) to indicate whether the implementation of `__moveinit__` is trivial. The implementation of `__moveinit__` is considered to be trivial if: * The struct has a compiler-generated `__moveinit__` and all its fields have a trivial `__moveinit__` method. In practice, it means the value can be moved by moving the bits from one location to another without side effects. ## Required methods ### `__copyinit__` `__copyinit__(out self: _Self, existing: _Self, /)` Create a new instance of the value by copying an existing one. 
**Args:** * ​existing (`_Self`): The value to copy. **Returns:** `_Self` ### `__moveinit__` `__moveinit__(out self: _Self, deinit existing: _Self, /)` Create a new instance of the value by moving the value of another. **Args:** * ​existing (`_Self`): The value to move. **Returns:** `_Self` ### `__add__` `__add__(self: _Self, offset: Int) -> _Self` Adds an offset to the operand descriptor. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The offset to add to the descriptor. **Returns:** `_Self`: A new descriptor with the offset applied. ## Provided methods ### `copy` `copy(self: _Self) -> _Self` Explicitly construct a copy of self. **Returns:** `_Self`: A copy of this value.
--- ## mma_operand_descriptor
## Traits * [​`MMAOperandDescriptor`](/mojo/stdlib/gpu/compute/mma_operand_descriptor/MMAOperandDescriptor): Trait for abstracting MMA (Matrix Multiply-Accumulate) operand descriptors.
--- ## mma_util
Matrix multiply accumulate (MMA) utilities for GPU tensor cores. This module provides functions for loading matrix tiles from memory into registers and storing results back to memory when using tensor cores for matrix multiplication. It supports both NVIDIA and AMD GPUs with functions specialized for different data types (FP32, FP16, BF16). The key functions are: * load\_matrix\_a: Loads tiles from the first input matrix A * load\_matrix\_b: Loads tiles from the second input matrix B * store\_matrix\_d: Stores result tiles to the output matrix D Each function handles the specific memory access patterns required by the tensor core instructions on each GPU architecture. The tile sizes and data layouts match the hardware requirements documented in the NVIDIA PTX ISA and AMD Matrix Cores references. ## Functions * [​`load_matrix_a`](/mojo/stdlib/gpu/compute/mma_util/load_matrix_a): Loads a tile of matrix A from memory to registers for TF32 tensor core operations. * [​`load_matrix_a_amd`](/mojo/stdlib/gpu/compute/mma_util/load_matrix_a_amd): Loads a tile of matrix A from memory to registers for AMD FP32 tensor core operations. * [​`load_matrix_b`](/mojo/stdlib/gpu/compute/mma_util/load_matrix_b): Loads a tile of matrix B from memory to registers for TF32 tensor core operations. * [​`load_matrix_b_amd`](/mojo/stdlib/gpu/compute/mma_util/load_matrix_b_amd): Loads a tile of matrix B from memory to registers for AMD FP32 tensor core operations. * [​`store_matrix_d`](/mojo/stdlib/gpu/compute/mma_util/store_matrix_d): Stores matrix D tile from registers to memory after tensor core operation.
--- ## load_matrix_a
`load_matrix_a[m: Int, n: Int, k: Int](a_ptr: UnsafePointer[Float32, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.float32, 4]` Loads a tile of matrix A from memory to registers for TF32 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=8, k=8. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 4 TF32 values loaded from matrix A in the required order. `load_matrix_a[m: Int, n: Int, k: Int](a_ptr: UnsafePointer[Float16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.float16, 4]` Loads a tile of matrix A from memory to registers for FP16 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=8, k=8. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. 
* ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 4 FP16 values loaded from matrix A in the required order. `load_matrix_a[m: Int, n: Int, k: Int](a_ptr: UnsafePointer[BFloat16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.bfloat16, (k // 2)]` Loads a tile of matrix A from memory to registers for BF16 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=8, k=8 or m=16, n=8, k=16. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing k//2 BF16 values loaded from matrix A in the required order.
--- ## load_matrix_a_amd
`load_matrix_a_amd[m: Int, n: Int, k: Int](a_ptr: UnsafePointer[Float32, origin], tile_row: Int, tile_col: Int, ldm: Int) -> Float32` Loads a tile of matrix A from memory to registers for AMD FP32 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=16, k=4. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** `Float32`: SIMD vector containing 1 FP32 value loaded from matrix A. `load_matrix_a_amd[m: Int, n: Int, k: Int, n_blocks: Int = 1](a_ptr: UnsafePointer[Float16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.float16, 4]` Loads a tile of matrix A from memory to registers for AMD FP16 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=16, k=16 and n\_blocks=1 or m=4, n=4, k=4 and n\_blocks=16. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. * ​n\_blocks ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of blocks. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. 
* ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 4 FP16 values loaded from matrix A. `load_matrix_a_amd[m: Int, n: Int, k: Int, n_blocks: Int = 1](a_ptr: UnsafePointer[BFloat16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.bfloat16, 4]` Loads a tile of matrix A from memory to registers for AMD BF16 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=16, k=16 and n\_blocks=1 or m=4, n=4, k=4 and n\_blocks=16. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. * ​n\_blocks ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of blocks. **Args:** * ​a\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix A data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix A (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 4 BF16 values loaded from matrix A.
--- ## load_matrix_b
`load_matrix_b[m: Int, n: Int, k: Int](b_ptr: UnsafePointer[Float32, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.float32, 2]` Loads a tile of matrix B from memory to registers for TF32 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=8, k=8. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 2 TF32 values loaded from matrix B in the required order. `load_matrix_b[m: Int, n: Int, k: Int](b_ptr: UnsafePointer[Float16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.float16, 2]` Loads a tile of matrix B from memory to registers for FP16 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=8, k=8. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. 
* ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 2 FP16 values loaded from matrix B in the required order. `load_matrix_b[m: Int, n: Int, k: Int](b_ptr: UnsafePointer[BFloat16, origin], tile_row: Int, tile_col: Int, ldm: Int) -> SIMD[DType.bfloat16, (k // 4)]` Loads a tile of matrix B from memory to registers for BF16 tensor core operations. **Constraints:** The tile dimensions must be m=16, n=8, k=8 or m=16, n=8, k=16. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing k//4 BF16 values loaded from matrix B in the required order.
--- ## load_matrix_b_amd
`load_matrix_b_amd[m: Int, n: Int, k: Int](b_ptr: UnsafePointer[Float32, origin], tile_row: Int, tile_col: Int, ldm: Int) -> Float32` Loads a tile of matrix B from memory to registers for AMD FP32 tensor core operations. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). **Returns:** `Float32`: SIMD vector containing 1 FP32 value loaded from matrix B. `load_matrix_b_amd[m: Int, n: Int, k: Int, n_blocks: Int = 1](b_ptr: UnsafePointer[Float16, origin], tile_row: Int, tile_col: Int, ldm: Int, tile_loops: Int = 1) -> SIMD[DType.float16, 4]` Loads a tile of matrix B from memory to registers for AMD FP16 tensor core operations. This function loads 4 consecutive FP16 values per thread from matrix B in a pattern optimized for AMD GPU tensor core operations. Each thread loads values based on its position within the warp. Performance: * Optimized for AMD GPU memory access patterns. * Uses thread ID to determine which elements to load. * Loads 4 consecutive elements per thread for efficient vectorization. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. 
* ​n\_blocks ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of blocks. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory (FP16 format). * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). * ​tile\_loops ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of tile loops across matrix B's row dimension. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 4 FP16 values loaded from matrix B. `load_matrix_b_amd[m: Int, n: Int, k: Int, n_blocks: Int = 1](b_ptr: UnsafePointer[BFloat16, origin], tile_row: Int, tile_col: Int, ldm: Int, tile_loops: Int = 1) -> SIMD[DType.bfloat16, 4]` Loads a tile of matrix B from memory to registers for AMD BF16 tensor core operations. This function loads 4 consecutive BF16 values per thread from matrix B in a pattern optimized for AMD GPU tensor core operations. Each thread loads values based on its position within the warp. Performance: * Optimized for AMD GPU memory access patterns. * Uses thread ID to determine which elements to load. * Loads 4 consecutive elements per thread for efficient vectorization. **Parameters:** * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in the output matrix tile. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in the output matrix tile. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiplication. * ​n\_blocks ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of blocks. **Args:** * ​b\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to matrix B data in memory (BF16 format). * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile. 
* ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension of matrix B (stride between rows). * ​tile\_loops ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of tile loops across matrix B's row dimension. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing 4 BF16 values loaded from matrix B.
--- ## store_matrix_d
`store_matrix_d[dtype: DType, //, m: Int, n: Int, k: Int, n_blocks: Int = 1](d_ptr: UnsafePointer[Scalar[dtype], origin], d: SIMD[dtype, 4], tile_row: Int, tile_col: Int, ldm: Int)` Stores matrix D tile from registers to memory after tensor core operation. This function dispatches to architecture-specific implementations for storing the results of a tensor core matrix multiply-accumulate operation. It handles the different memory layouts required by NVIDIA and AMD tensor cores. Note: * Automatically selects appropriate implementation based on GPU architecture. * Each thread stores 4 elements in architecture-specific positions. * Must be called by all threads in a warp. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type of the matrix elements. * ​m ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of rows in matrix D. * ​n ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of columns in matrix D. * ​k ([`Int`](/mojo/stdlib/builtin/int/Int)): Inner dimension for matrix multiply. * ​n\_blocks ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of blocks. **Args:** * ​d\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to destination memory for matrix D. * ​d ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): SIMD vector containing 4 elements to store. * ​tile\_row ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting row index of the tile in matrix D. * ​tile\_col ([`Int`](/mojo/stdlib/builtin/int/Int)): Starting column index of the tile in matrix D. * ​ldm ([`Int`](/mojo/stdlib/builtin/int/Int)): Leading dimension (stride) of matrix D.
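Putting the `mma_util` helpers together, a hedged sketch of one m16n8k8 TF32 tile step (pointers, tile indices, and leading dimensions are assumed kernel arguments; the fragment widths follow the load/store helpers above, and the exact `mma` overload accepting them is assumed):

```mojo
from gpu.compute.mma import mma
from gpu.compute.mma_util import load_matrix_a, load_matrix_b, store_matrix_d

fn tile_step(
    a_ptr: UnsafePointer[Float32],
    b_ptr: UnsafePointer[Float32],
    d_ptr: UnsafePointer[Float32],
    tile_row: Int, tile_col: Int,
    lda: Int, ldb: Int, ldd: Int,
):
    # Per-thread fragments for one m16n8k8 TF32 MMA.
    var a = load_matrix_a[16, 8, 8](a_ptr, tile_row, 0, lda)  # 4 values
    var b = load_matrix_b[16, 8, 8](b_ptr, 0, tile_col, ldb)  # 2 values
    var d = SIMD[DType.float32, 4](0)

    mma(d, a, b, d)  # accumulate into d on the tensor cores
    store_matrix_d[16, 8, 8](d_ptr, d, tile_row, tile_col, ldd)
```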
--- ## TensorMemory
`@register_passable(trivial)` `struct TensorMemory` A wrapper around tensor memory allocated for tcgen05 instructions. ## Fields * ​ptr (`UnsafePointer[UInt32, MutAnyOrigin, address_space=AddressSpace.SHARED]`): Pointer to the tensor memory address. * ​num\_cols (`UInt32`): The number of columns in the tensor memory. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `__init__(num_cols: UInt32) -> Self` Initialize the TensorMemory struct. **Args:** * ​num\_cols ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The number of columns to allocate.
--- ## tcgen05
This module includes utilities for working with the tensorcore 5th generation (tcgen05) instructions. ## `comptime` values ### `check_blackwell_constraint` `comptime check_blackwell_constraint = constrained[_has_blackwell_tcgen05(), "The tcgen05 instructions are only applicable on nVidia Blackwell (sm_100a, sm_101a) hardware.", ?]` Compile-time constraint ensuring Blackwell hardware is targeted. ## Structs * [​`TensorMemory`](/mojo/stdlib/gpu/compute/tcgen05/TensorMemory): A wrapper around tensor memory allocated for tcgen05 instructions. ## Functions * [​`tcgen05_alloc`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_alloc): Allocates tensor memory for use with tcgen05 instructions. * [​`tcgen05_cp`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_cp): Copies data from shared memory described by the matrix descriptor `s_desc` to tensor memory `tmem_addr`. * [​`tcgen05_dealloc`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_dealloc): Deallocates tensor memory allocated by tcgen05\_alloc(). * [​`tcgen05_fence_after`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_fence_after): Orders all the subsequent asynchronous `tcgen05` operations. * [​`tcgen05_fence_before`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_fence_before): Orders all the prior asynchronous `tcgen05` operations. * [​`tcgen05_ld`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_ld): Loads data from tensor memory into registers. * [​`tcgen05_load_wait`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_load_wait): Waits for tensor memory loads to complete. * [​`tcgen05_release_allocation_lock`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_release_allocation_lock): Releases the allocation lock for the current CTA group. * [​`tcgen05_st`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_st): Stores data from registers into tensor memory. * [​`tcgen05_store_wait`](/mojo/stdlib/gpu/compute/tcgen05/tcgen05_store_wait): Waits for tensor memory stores to complete.
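The functions above compose into a standard lifecycle: allocate tensor memory, release the allocation lock, operate on it, then deallocate from the same CTA group. A minimal sketch of that flow, assuming this import path and a `stack_allocation`-based shared-memory slot for the returned address (both assumptions; see the individual reference entries for the exact signatures):

```mojo
from gpu.compute.tcgen05 import (
    tcgen05_alloc,
    tcgen05_dealloc,
    tcgen05_release_allocation_lock,
)
from gpu.memory import AddressSpace
from memory import stack_allocation

fn tmem_lifecycle_kernel():
    # Shared-memory slot that tcgen05_alloc fills with the tensor
    # memory address.
    var tmem_addr_ptr = stack_allocation[
        1, UInt32, address_space = AddressSpace.SHARED
    ]()
    var num_cols: UInt32 = 32

    # Allocate for CTA group 1, then release the allocation lock so
    # later allocations can proceed.
    tcgen05_alloc[1](tmem_addr_ptr, num_cols)
    tcgen05_release_allocation_lock[1]()

    var tmem_addr = tmem_addr_ptr[0]
    # ... issue tcgen05_ld / tcgen05_st / tcgen05_cp against tmem_addr ...

    # Deallocation must be performed by the same CTA group that
    # performed the allocation.
    tcgen05_dealloc[1](tmem_addr, num_cols)
```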
--- ## tcgen05_alloc
`tcgen05_alloc[cta_group: Int32](ptr_tmem_addr: UnsafePointer[UInt32, origin, address_space=AddressSpace.SHARED], num_cols: UInt32)` Allocates tensor memory for use with tcgen05 instructions. Note: This function is only available on NVIDIA Blackwell GPUs (SM 100+). **Parameters:** * ​cta\_group ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The cooperative thread array (CTA) group ID. **Args:** * ​ptr\_tmem\_addr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Shared memory pointer to hold tensor memory address. * ​num\_cols ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The number of columns to allocate.
--- ## tcgen05_cp
`tcgen05_cp[*, cta_group: Int32, datapaths: Int, bits: Int, src_fmt: String = "", dst_fmt: String = "", multicast: String = ""](tmem_addr: UInt32, s_desc: MMASmemDescriptor)` Copies data from shared memory described by the matrix descriptor `s_desc` to tensor memory `tmem_addr`. Note: This function is only available on NVIDIA Blackwell GPUs (SM 100+). **Parameters:** * ​cta\_group ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The cooperative thread array (CTA) group ID. * ​datapaths ([`Int`](/mojo/stdlib/builtin/int/Int)): The first dimension of the shape. * ​bits ([`Int`](/mojo/stdlib/builtin/int/Int)): The second dimension of the shape. * ​src\_fmt ([`String`](/mojo/stdlib/collections/string/string/String)): Source format string. * ​dst\_fmt ([`String`](/mojo/stdlib/collections/string/string/String)): Destination format string. * ​multicast ([`String`](/mojo/stdlib/collections/string/string/String)): Multicast string. **Args:** * ​tmem\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Address of the tensor memory. * ​s\_desc ([`MMASmemDescriptor`](/mojo/stdlib/gpu/compute/arch/mma_nvidia_sm100/MMASmemDescriptor)): Matrix descriptor for the copy operation.
--- ## tcgen05_dealloc
`tcgen05_dealloc[cta_group: Int32](tmem_addr: UInt32, num_cols: UInt32)` Deallocates tensor memory allocated by tcgen05\_alloc(). This function deallocates tensor memory that was previously allocated using tcgen05\_alloc(). The deallocation must be performed by the same CTA group that performed the allocation. **Parameters:** * ​cta\_group ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The cooperative thread array (CTA) group ID. **Args:** * ​tmem\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Address of the tensor memory to deallocate. * ​num\_cols ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Number of columns in the tensor memory.
--- ## tcgen05_fence_after
`tcgen05_fence_after()` Orders all the subsequent asynchronous `tcgen05` operations. Note: This function is only available on NVIDIA Blackwell GPUs (SM 100+).
--- ## tcgen05_fence_before
`tcgen05_fence_before()` Orders all the prior asynchronous `tcgen05` operations. Note: This function is only available on NVIDIA Blackwell GPUs (SM 100+).
--- ## tcgen05_ld
`tcgen05_ld[*, datapaths: Int, bits: Int, repeat: Int, dtype: DType, pack: Bool, width: Int = (((datapaths * bits) * repeat) // 1024)](tmem_addr: UInt32) -> SIMD[dtype, width]` Loads data from tensor memory into registers. **Parameters:** * ​datapaths ([`Int`](/mojo/stdlib/builtin/int/Int)): The first dimension of the shape. * ​bits ([`Int`](/mojo/stdlib/builtin/int/Int)): The second dimension of the shape. * ​repeat ([`Int`](/mojo/stdlib/builtin/int/Int)): The repeat factor. * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type to load. * ​pack ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to pack two 16-bit chunks of adjacent columns into a single 32-bit register. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the result vector. **Args:** * ​tmem\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the tensor memory to load from. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): The SIMD register containing the loaded data.
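Note that `width` defaults to `(datapaths * bits * repeat) // 1024`, the per-thread element count for the shape. A sketch of a load paired with its completion wait (the import path and this particular shape combination are assumptions):

```mojo
from gpu.compute.tcgen05 import tcgen05_ld, tcgen05_load_wait

fn load_tile(tmem_addr: UInt32) -> SIMD[DType.float32, 4]:
    # 32 datapaths x 32 bits, repeated 4 times:
    # width = (32 * 32 * 4) // 1024 = 4 elements per thread.
    var frag = tcgen05_ld[
        datapaths=32, bits=32, repeat=4, dtype = DType.float32, pack=False
    ](tmem_addr)
    tcgen05_load_wait()  # Block until the tensor memory load completes.
    return frag
```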
--- ## tcgen05_load_wait
`tcgen05_load_wait()` Waits for tensor memory loads to complete. Note: This function is only available on NVIDIA Blackwell GPUs (SM 100+).
--- ## tcgen05_release_allocation_lock
`tcgen05_release_allocation_lock[cta_group: Int32]()` Releases the allocation lock for the current CTA group. Note: This function is only available on NVIDIA Blackwell GPUs (SM 100+). **Parameters:** * ​cta\_group ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): The cooperative thread array (CTA) group ID.
--- ## tcgen05_st
`tcgen05_st[dtype: DType, width: Int, //, *, datapaths: Int, bits: Int, repeat: Int, pack: Bool](tmem_addr: UInt32, data: SIMD[dtype, width])` Stores data from registers into tensor memory. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type to store. * ​width ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the data vector. * ​datapaths ([`Int`](/mojo/stdlib/builtin/int/Int)): The first dimension of the shape. * ​bits ([`Int`](/mojo/stdlib/builtin/int/Int)): The second dimension of the shape. * ​repeat ([`Int`](/mojo/stdlib/builtin/int/Int)): The repeat factor. * ​pack ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to pack two 16-bit chunks of adjacent columns into a single 32-bit register. **Args:** * ​tmem\_addr ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The address of the tensor memory to store to. * ​data ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): The data to store into the tensor memory.
--- ## tcgen05_store_wait
`tcgen05_store_wait()` Waits for tensor memory stores to complete. Note: This function is only available on NVIDIA Blackwell GPUs (SM 100+).
--- ## tensor_ops
This module provides tensor core operations and utilities for GPU computation. The module includes functions for: * Tensor core based reductions (tc\_reduce) supporting various data types and SIMD widths * GEVM (General Matrix-Vector Multiplication) reductions using tensor cores * Efficient warp-level reductions leveraging tensor core operations The tensor core operations are optimized for NVIDIA GPUs and support different data types including float32, float16, and bfloat16. The module provides both scalar and vector variants of reduction operations with different SIMD widths for maximum performance. Key functions: * tc\_reduce: Main tensor core reduction function supporting various types and widths * tc\_reduce\_gevm\_8x: 8x GEVM reduction using tensor cores * tc\_reduce\_gevm\_4x: 4x GEVM reduction using tensor cores Note: Most operations require NVIDIA GPUs with tensor core support. Operations are optimized for warp-level execution. ## Functions * [​`tc_reduce`](/mojo/stdlib/gpu/compute/tensor_ops/tc_reduce): Performs tensor core based reduction on a SIMD vector. * [​`tc_reduce_gevm_4x`](/mojo/stdlib/gpu/compute/tensor_ops/tc_reduce_gevm_4x): Performs a 4x GEVM reduction using tensor cores. * [​`tc_reduce_gevm_8x`](/mojo/stdlib/gpu/compute/tensor_ops/tc_reduce_gevm_8x): Performs an 8x GEVM reduction using tensor cores.
--- ## tc_reduce
`tc_reduce[in_type: DType, simd_width: Int, //, out_type: DType](val: SIMD[in_type, simd_width]) -> Scalar[out_type]` Performs tensor core based reduction on a SIMD vector. Note: Dispatches to either scalar or vector reduction implementation based on SIMD width. Supports various input/output type combinations using tensor core operations. **Parameters:** * ​in\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The input data type of the SIMD vector elements. * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the SIMD vector. * ​out\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The output data type for the reduced result. **Args:** * ​val ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Input SIMD vector to reduce. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): Scalar containing the reduced result.
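Because `in_type` and `simd_width` are infer-only parameters, callers typically pass only `out_type`. A minimal sketch, assuming this import path:

```mojo
from gpu.compute.tensor_ops import tc_reduce

fn reduce_frag(vals: SIMD[DType.bfloat16, 4]) -> Float32:
    # Reduce a 4-wide bfloat16 vector to a single float32 scalar;
    # in_type and simd_width are inferred from `vals`.
    return tc_reduce[DType.float32](vals)
```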
--- ## tc_reduce_gevm_4x
`tc_reduce_gevm_4x[out_type: DType, in_type: DType, simd_width: Int](val1: SIMD[in_type, simd_width]) -> SIMD[out_type, simd_width]` Performs a 4x GEVM reduction using tensor cores. Note: Currently only supports bfloat16 input to float32 output conversion. Uses tensor core matrix multiply-accumulate (MMA) operations for reduction. **Parameters:** * ​out\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The output data type for the reduction result (must be float32). * ​in\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The input data type of the vector to reduce (must be bfloat16). * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the SIMD vector. **Args:** * ​val1 ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Input SIMD vector to reduce. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing the reduced result.
--- ## tc_reduce_gevm_8x
`tc_reduce_gevm_8x[out_type: DType, in_type: DType, simd_width: Int](val1: SIMD[in_type, simd_width], val2: SIMD[in_type, simd_width]) -> SIMD[out_type, simd_width]` Performs an 8x GEVM reduction using tensor cores. Note: Currently only supports bfloat16 input to float32 output conversion. Uses tensor core matrix multiply-accumulate (MMA) operations for reduction. **Parameters:** * ​out\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The output data type for the reduction result (must be float32). * ​in\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The input data type of the vectors to reduce (must be bfloat16). * ​simd\_width ([`Int`](/mojo/stdlib/builtin/int/Int)): The width of the SIMD vectors. **Args:** * ​val1 ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): First input SIMD vector to reduce. * ​val2 ([`SIMD`](/mojo/stdlib/builtin/simd/SIMD)): Second input SIMD vector to reduce. **Returns:** [`SIMD`](/mojo/stdlib/builtin/simd/SIMD): SIMD vector containing the reduced result.
--- ## globals (Globals)
This module provides GPU-specific global constants and configuration values. The module defines hardware-specific constants like warp size and thread block limits that are used throughout the GPU programming interface. It handles both NVIDIA and AMD GPU architectures, automatically detecting and configuring the appropriate values based on the available hardware. The constants are resolved at compile time based on the target GPU architecture and are used to optimize code generation and ensure hardware compatibility. ## `comptime` values ### `MAX_THREADS_PER_BLOCK_METADATA` `comptime MAX_THREADS_PER_BLOCK_METADATA = _resolve_max_threads_per_block_metadata()` This is a metadata tag that is used in conjunction with \_\_llvm\_metadata to give the compiler a hint about the maximum number of threads per block in use. ### `WARP_SIZE` `comptime WARP_SIZE = _resolve_warp_size()` The number of threads that execute in lockstep within a warp on the GPU. This constant represents the hardware warp size, which is the number of threads that execute instructions synchronously as a unit. The value is architecture-dependent: * 32 threads per warp on NVIDIA GPUs * 32 threads per warp on AMD RDNA GPUs * 64 threads per warp on AMD CDNA GPUs * 0 if no GPU is detected The warp size is a fundamental parameter that affects: * Thread scheduling and execution * Memory access coalescing * Synchronization primitives * Overall performance optimization ### `WARPGROUP_SIZE` `comptime WARPGROUP_SIZE = _resolve_warpgroup_size()` The number of threads in a warpgroup on NVIDIA GPUs. On NVIDIA GPUs from Hopper onward, a warpgroup consists of 4 consecutive warps, i.e. 128 threads, and the ID of the first warp must be a multiple of 4. Warpgroups are used for wgmma instructions on Hopper and tcgen05.ld on Blackwell.
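Because `WARP_SIZE` resolves at compile time, it can be used directly in index arithmetic inside kernels. A small sketch (the top-level `gpu` import of `WARP_SIZE` is an assumption):

```mojo
from gpu import WARP_SIZE, thread_idx

fn report_lane():
    # WARP_SIZE is 32 on NVIDIA and AMD RDNA GPUs, 64 on AMD CDNA GPUs.
    var lane = Int(thread_idx.x) % WARP_SIZE
    var warp_id = Int(thread_idx.x) // WARP_SIZE
    print("warp:", warp_id, "lane:", lane)
```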
--- ## grid_controls
GPU grid dependency control (deprecated - use `gpu.primitives.grid_controls` or `gpu`). This module is deprecated. For new code, import grid control operations from the `gpu` package or `gpu.primitives.grid_controls` module: ```mojo # Deprecated: from gpu.grid_controls import PDL, PDLLevel, launch_dependent_grids # Recommended (import from top-level gpu package): from gpu import PDL, PDLLevel, launch_dependent_grids # Or import the module: from gpu.primitives import grid_controls ``` This module provides Hopper PDL (Programmable Distributed Launch) operations for controlling grid dependencies on NVIDIA GPUs.
--- ## get_gpu_target
`get_gpu_target[target_arch: StringSlice[StaticConstantOrigin] = _accelerator_arch()]() -> __mlir_type.`!kgen.target\`\` Gets the GPU target information for the specified architecture. **Parameters:** * ​target\_arch ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): GPU architecture name (defaults to current accelerator architecture). **Returns:** `__mlir_type.`!kgen.target\`\`: Target type information for the specified GPU architecture.
--- ## compile (3)
Implements CUDA compilation operations. ## Functions * [​`get_gpu_target`](/mojo/stdlib/gpu/host/compile/get_gpu_target): Gets the GPU target information for the specified architecture.
--- ## ConstantMemoryMapping
`@register_passable(trivial)` `struct ConstantMemoryMapping` Represents a mapping of constant memory between host and device. This struct encapsulates the information needed to manage constant memory that can be accessed by GPU kernels. Constant memory provides a fast, read-only cache accessible by all threads on the GPU device. Attributes: name: A string identifier for the constant memory mapping. ptr: Pointer to the memory location. byte\_count: Size of the memory mapping in bytes. ## Fields * ​name (`StaticString`): A string identifier for the constant memory mapping. This name is used to uniquely identify the constant memory region in the GPU programming model, allowing the runtime to properly associate the memory with kernel references to constant memory symbols. * ​ptr (`LegacyOpaquePointer`): Pointer to the host memory location that will be mapped to device constant memory. This raw pointer represents the starting address of the memory region that will be accessible as constant memory on the GPU. The memory should remain valid for the lifetime of any kernels that access it. * ​byte\_count (`Int`): Size of the memory mapping in bytes. Specifies the total size of the constant memory region. This value is used by the runtime to determine how much data to transfer between host and device. The size must be sufficient to hold all data needed by GPU kernels. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True`
--- ## constant_memory_mapping
This module provides functionality for mapping constant memory between host and device. The module includes the `ConstantMemoryMapping` struct which represents a mapping of constant memory that can be used for efficient data transfer between host and GPU device. ## Structs * [​`ConstantMemoryMapping`](/mojo/stdlib/gpu/host/constant_memory_mapping/ConstantMemoryMapping): Represents a mapping of constant memory between host and device.
--- ## DeviceAttribute
`@register_passable(trivial)` `struct DeviceAttribute` Represents CUDA device attributes that can be queried from a GPU device. This struct encapsulates the various device properties and capabilities that can be queried through the CUDA driver API. Each attribute is represented as a constant with a corresponding integer value that maps to the CUDA driver's attribute enum. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `CLOCK_RATE` `comptime CLOCK_RATE = DeviceAttribute(13)` Typical clock frequency in kilohertz. ### `COMPUTE_CAPABILITY_MAJOR` `comptime COMPUTE_CAPABILITY_MAJOR = DeviceAttribute(75)` Major compute capability version number. ### `COMPUTE_CAPABILITY_MINOR` `comptime COMPUTE_CAPABILITY_MINOR = DeviceAttribute(76)` Minor compute capability version number. ### `COOPERATIVE_LAUNCH` `comptime COOPERATIVE_LAUNCH = DeviceAttribute(95)` Device supports launching cooperative kernels. ### `MAX_ACCESS_POLICY_WINDOW_SIZE` `comptime MAX_ACCESS_POLICY_WINDOW_SIZE = DeviceAttribute(109)` CUDA-only: Maximum value of CUaccessPolicyWindow::num\_bytes. ### `MAX_BLOCK_DIM_X` `comptime MAX_BLOCK_DIM_X = DeviceAttribute(2)` Maximum block dimension X. ### `MAX_BLOCK_DIM_Y` `comptime MAX_BLOCK_DIM_Y = DeviceAttribute(3)` Maximum block dimension Y. ### `MAX_BLOCK_DIM_Z` `comptime MAX_BLOCK_DIM_Z = DeviceAttribute(4)` Maximum block dimension Z. 
### `MAX_BLOCKS_PER_MULTIPROCESSOR` `comptime MAX_BLOCKS_PER_MULTIPROCESSOR = DeviceAttribute(106)` Maximum resident blocks per multiprocessor. ### `MAX_GRID_DIM_X` `comptime MAX_GRID_DIM_X = DeviceAttribute(5)` Maximum grid dimension X. ### `MAX_GRID_DIM_Y` `comptime MAX_GRID_DIM_Y = DeviceAttribute(6)` Maximum grid dimension Y. ### `MAX_GRID_DIM_Z` `comptime MAX_GRID_DIM_Z = DeviceAttribute(7)` Maximum grid dimension Z. ### `MAX_REGISTERS_PER_BLOCK` `comptime MAX_REGISTERS_PER_BLOCK = DeviceAttribute(12)` Maximum number of 32-bit registers available per block. ### `MAX_REGISTERS_PER_MULTIPROCESSOR` `comptime MAX_REGISTERS_PER_MULTIPROCESSOR = DeviceAttribute(82)` Maximum number of 32-bit registers available per multiprocessor. ### `MAX_SHARED_MEMORY_PER_BLOCK` `comptime MAX_SHARED_MEMORY_PER_BLOCK = DeviceAttribute(8)` Maximum shared memory available per block in bytes. ### `MAX_SHARED_MEMORY_PER_BLOCK_OPTIN` `comptime MAX_SHARED_MEMORY_PER_BLOCK_OPTIN = DeviceAttribute(97)` Maximum shared memory per block usable via `cudaFuncSetAttribute`. ### `MAX_SHARED_MEMORY_PER_MULTIPROCESSOR` `comptime MAX_SHARED_MEMORY_PER_MULTIPROCESSOR = DeviceAttribute(81)` Maximum shared memory available per multiprocessor in bytes. ### `MAX_THREADS_PER_BLOCK` `comptime MAX_THREADS_PER_BLOCK = DeviceAttribute(1)` Maximum number of threads per block. ### `MAX_THREADS_PER_MULTIPROCESSOR` `comptime MAX_THREADS_PER_MULTIPROCESSOR = DeviceAttribute(39)` Maximum resident threads per multiprocessor. ### `MULTIPROCESSOR_COUNT` `comptime MULTIPROCESSOR_COUNT = DeviceAttribute(16)` Number of multiprocessors on device. ### `WARP_SIZE` `comptime WARP_SIZE = DeviceAttribute(10)` Warp size in threads.
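As a sketch of the query pattern, assuming a `get_attribute` method on `DeviceContext` and a `gpu.host` import of `DeviceAttribute` (both are assumptions; see the `DeviceContext` documentation for authoritative examples):

```mojo
from gpu.host import DeviceContext, DeviceAttribute

fn print_limits() raises:
    var ctx = DeviceContext()
    # get_attribute is the assumed query entry point on DeviceContext.
    var max_threads = ctx.get_attribute(DeviceAttribute.MAX_THREADS_PER_BLOCK)
    var sm_count = ctx.get_attribute(DeviceAttribute.MULTIPROCESSOR_COUNT)
    print("max threads/block:", max_threads, "SMs:", sm_count)
```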
--- ## device_attribute
This module defines GPU device attributes that can be queried from CUDA-compatible devices. The module provides the `DeviceAttribute` struct which encapsulates the various device properties and capabilities that can be queried through the CUDA driver API. Each attribute is represented as a constant with a corresponding integer value that maps to the CUDA driver's attribute enumeration. These attributes allow applications to query specific hardware capabilities and limitations of GPU devices, such as maximum thread counts, memory sizes, compute capabilities, and supported features. :::note See the [`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext/) page for examples that retrieve `DeviceAttribute` values. ::: ## Structs * [​`DeviceAttribute`](/mojo/stdlib/gpu/host/device_attribute/DeviceAttribute): Represents CUDA device attributes that can be queried from a GPU device.
--- ## DeviceBuffer
`struct DeviceBuffer[dtype: DType]` Represents a block of device-resident storage. For GPU devices, a device buffer is allocated in the device's global memory. To allocate a `DeviceBuffer`, use one of the methods provided by `DeviceContext`, such as [`enqueue_create_buffer()`](/mojo/stdlib/gpu/host/device_context/DeviceContext#enqueue_create_buffer). ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data dtype to be stored in the buffer. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `device_type` `comptime device_type = LegacyUnsafePointer[Scalar[dtype]]` DeviceBuffer dtypes are remapped to UnsafePointer when passed to accelerator devices. ## Methods ### `__copyinit__` `__copyinit__(out self, existing: Self)` Creates a copy of an existing device buffer by incrementing its reference count. This copy constructor creates a new reference to the same underlying device buffer by incrementing the reference count of the native buffer object. Both the original and the copy will refer to the same memory on the device. **Args:** * ​existing (`Self`): The device buffer to copy. ### `__del__` `__del__(deinit self)` Releases resources associated with this device buffer. 
This function schedules an owned buffer free using the stream in the device context. The actual deallocation may occur asynchronously after all operations using this buffer have completed. ### `get_type_name` `static get_type_name() -> String` Gets this dtype's name, for use in error messages when handing arguments to kernels. TODO: This will go away soon, when we get better error messages for kernel calls. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This dtype's name. ### `get_device_type_name` `static get_device_type_name() -> String` Gets device\_type's name, for use in error messages when handing arguments to kernels. TODO: This will go away soon, when we get better error messages for kernel calls. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): This dtype's name. ### `__len__` `__len__(self) -> Int` Returns the number of elements in this buffer. This method calculates the number of elements by dividing the total byte size of the buffer by the size of each element. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the buffer. ### `create_sub_buffer` `create_sub_buffer[view_type: DType](self, offset: Int, size: Int) -> DeviceBuffer[view_type]` Creates a sub-buffer view of this buffer with a different element dtype. This method creates a new buffer that references a subset of the memory in this buffer, potentially with a different element dtype. The sub-buffer shares the underlying memory with the original buffer. **Parameters:** * ​view\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for elements in the new sub-buffer. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The starting offset in elements from the beginning of this buffer. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the new sub-buffer. 
**Returns:** [`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer): A new DeviceBuffer referencing the specified region with the specified element dtype. **Raises:** If the operation fails. ### `enqueue_copy_to` `enqueue_copy_to(self, dst: Self)` Enqueues an asynchronous copy from this buffer to another device buffer. This method schedules a memory copy operation from this buffer to the destination buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​dst (`Self`): The destination device buffer to copy data to. **Raises:** If the operation fails. `enqueue_copy_to(self, dst: HostBuffer[dtype])` Enqueues an asynchronous copy from this buffer to a host buffer. This method schedules a memory copy operation from this buffer to the destination buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​dst ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): The destination host buffer to copy data to. **Raises:** If the operation fails. `enqueue_copy_to(self, dst_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin])` Enqueues an asynchronous copy from this buffer to host memory. This method schedules a memory copy operation from this device buffer to the specified host memory location. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​dst\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to the destination host memory location. **Raises:** If the operation fails. ### `enqueue_copy_from` `enqueue_copy_from(self, src: Self)` Enqueues an asynchronous copy to this buffer from another device buffer. This method schedules a memory copy operation to this buffer from the source buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. 
**Args:** * ​src (`Self`): The source device buffer to copy data from. **Raises:** If the operation fails. `enqueue_copy_from(self, src: HostBuffer[dtype])` Enqueues an asynchronous copy to this buffer from a host buffer. This method schedules a memory copy operation to this buffer from the source buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​src ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): The source host buffer to copy data from. **Raises:** If the operation fails. `enqueue_copy_from(self, src_ptr: UnsafePointer[Scalar[dtype], origin])` Enqueues an asynchronous copy to this buffer from host memory. This method schedules a memory copy operation to this device buffer from the specified host memory location. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​src\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to the source host memory location. **Raises:** If the operation fails. ### `enqueue_fill` `enqueue_fill(self, val: Scalar[dtype])` Enqueues an operation to fill this buffer with a specified value. This method schedules a memory set operation that fills the entire buffer with the specified value. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to fill the buffer with. **Raises:** If the operation fails. ### `reassign_ownership_to` `reassign_ownership_to(self, ctx: DeviceContext)` Transfers ownership of this buffer to another device context. This method changes the device context that owns this buffer. This can be useful when sharing buffers between different contexts or when migrating workloads between devices. 
**Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The new device context to take ownership of this buffer. **Raises:** If the operation fails. ### `take_ptr` `take_ptr(var self) -> UnsafePointer[Scalar[dtype], MutAnyOrigin]` Takes ownership of the device pointer from this buffer. This method releases the device pointer from the buffer's control and returns it to the caller. After this call, the buffer no longer owns the pointer, and the caller is responsible for managing its lifecycle. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The raw device pointer that was owned by this buffer. ### `unsafe_ptr` `unsafe_ptr(self) -> UnsafePointer[Scalar[dtype], MutAnyOrigin]` Returns the raw device pointer without transferring ownership. This method provides direct access to the underlying device pointer for advanced use cases. The buffer retains ownership of the pointer. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The raw device pointer owned by this buffer. ### `context` `context(self) -> DeviceContext` Returns the device context associated with this buffer. This method retrieves the device context that owns this buffer and is responsible for managing its lifecycle and operations. **Returns:** [`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext): The device context associated with this buffer. **Raises:** If the operation fails. ### `map_to_host` `map_to_host(self, out mapped_buffer: _HostMappedBuffer[dtype])` Maps this device buffer to host memory for CPU access. This method creates a host-accessible view of the device buffer's contents. The mapping operation may involve copying data from device to host memory. Notes: Values modified inside the `with` statement are updated on the device when the `with` statement exits. 
Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() var length = 1024 var in_dev = ctx.enqueue_create_buffer[DType.float32](length) var out_dev = ctx.enqueue_create_buffer[DType.float32](length) # Initialize the input and output with known values. with in_dev.map_to_host() as in_host, out_dev.map_to_host() as out_host: for i in range(length): in_host[i] = i out_host[i] = 255 ``` **Returns:** [`_HostMappedBuffer`](/mojo/stdlib/gpu/host/device_context/_HostMappedBuffer): A host-mapped buffer that provides CPU access to the device buffer's contents inside a with-statement. **Raises:** If there's an error during buffer creation or data transfer. ### `write_to` `write_to(self, mut writer: T)` Writes a string representation of this buffer to the provided writer. This method formats the buffer's contents as a string and writes it to the specified writer. For large buffers, a compact representation is used. **Args:** * ​writer (`T`): The writer to output the formatted string to. ### `__str__` `__str__(self) -> String` Returns a string representation of the `DeviceBuffer`. This method creates a human-readable string representation of the buffer's contents by mapping the device memory to host memory and formatting the elements. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the formatted buffer contents.
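The `enqueue_*` methods above chain on a single stream: create, fill, copy back, then synchronize. A sketch of a device-to-host round trip, consistent with the `HostBuffer` overloads documented above (the `enqueue_create_host_buffer` call is an assumption):

```mojo
from gpu.host import DeviceContext

fn fill_and_read_back() raises:
    var ctx = DeviceContext()
    var length = 16
    var dev = ctx.enqueue_create_buffer[DType.float32](length)
    var host = ctx.enqueue_create_host_buffer[DType.float32](length)

    # Fill on the device, copy back to the host, then wait for the
    # stream so the host buffer is safe to read.
    dev.enqueue_fill(42)
    dev.enqueue_copy_to(host)
    ctx.synchronize()

    print(host[0])  # every element should now hold 42.0
```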
--- ## DeviceContext
`@register_passable` `struct DeviceContext` Represents a single stream of execution on a particular accelerator (GPU). A `DeviceContext` serves as the low-level interface to the accelerator inside a MAX [custom operation](/max/develop/custom-ops/) and provides methods for allocating buffers on the device, copying data between host and device, and for compiling and running functions (also known as kernels) on the device. The device context can be used as a [context manager](/mojo/manual/errors#use-a-context-manager). For example: ```mojo from gpu.host import DeviceContext from gpu import thread_idx fn kernel(): print("hello from thread:", thread_idx.x, thread_idx.y, thread_idx.z) with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=(2, 2, 2)) ctx.synchronize() ``` A custom operation receives an opaque `DeviceContextPtr`, which provides a `get_device_context()` method to retrieve the device context: ```mojo from runtime.asyncrt import DeviceContextPtr @register("custom_op") struct CustomOp: @staticmethod fn execute(ctx_ptr: DeviceContextPtr) raises: var ctx = ctx_ptr.get_device_context() ctx.enqueue_function[kernel](grid_dim=1, block_dim=(2, 2, 2)) ctx.synchronize() ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `default_device_info` `comptime default_device_info = GPUInfo.from_name[_accelerator_arch()]()` `GPUInfo` object for the default accelerator. 
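As a small illustration, `default_device_info` can be read directly; only its `api` field (the same value used as the `__init__` default below) is assumed here:

```mojo
from gpu.host import DeviceContext

# Resolved at compile time from the default accelerator target.
print("Default device API:", DeviceContext.default_device_info.api)
```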
## Methods ### `__init__` `__init__(out self, device_id: Int = 0, *, var api: String = DeviceContext.default_device_info.api)` Constructs a `DeviceContext` for the specified device. This initializer creates a new device context for the specified accelerator device. The device context provides an interface for interacting with the GPU, including memory allocation, data transfer, and kernel execution. Example: ```mojo from gpu.host import DeviceContext # Create a context for the default GPU var ctx = DeviceContext() # Create a context for a specific GPU (device 1) var ctx2 = DeviceContext(1) ``` **Args:** * ​device\_id ([`Int`](/mojo/stdlib/builtin/int/Int)): ID of the accelerator device. If not specified, uses the default accelerator (device 0). * ​api ([`String`](/mojo/stdlib/collections/string/string/String)): Requested device API (for example, "cuda" or "hip"). Defaults to the device API specified by current target accelerator. **Raises:** If device initialization fails or the specified device is not available. ### `__copyinit__` `__copyinit__(existing: Self) -> Self` Creates a copy of an existing device context by incrementing its reference count. This copy constructor creates a new reference to the same underlying device context by incrementing the reference count of the native context object. Both the original and the copy will refer to the same device context. **Args:** * ​existing (`Self`): The device context to copy. ### `__del__` `__del__(deinit self)` Releases resources associated with this device context. This destructor decrements the reference count of the native device context. When the reference count reaches zero, the underlying resources are released, including any cached memory buffers and compiled device functions. ### `__enter__` `__enter__(var self) -> Self` Enables the use of DeviceContext in a 'with' statement context manager. 
This method allows DeviceContext to be used with Python-style context managers, which ensures proper resource management and cleanup when the context exits. Example: ```mojo from gpu.host import DeviceContext # Using DeviceContext as a context manager with DeviceContext() as ctx: # Perform GPU operations; resources are released when the block exits print("Running on:", ctx.name()) ``` **Returns:** `Self`: The DeviceContext instance to be used within the context manager block. ### `name` `name(self) -> String` Returns the device name, an ASCII string identifying this device, defined by the native device API. This method queries the underlying GPU device for its name, which typically includes the model and other identifying information. This can be useful for logging, debugging, or making runtime decisions based on the specific GPU hardware. Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() print("Running on device:", ctx.name()) ``` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the device name. ### `api` `api(self) -> String` Returns the name of the API used to program the device. This method queries the underlying device context to determine which GPU programming API is being used for the current device. This information is useful for writing code that can adapt to different GPU architectures and programming models. Possible values are: * "cpu": Generic host device (CPU). * "cuda": NVIDIA GPUs. * "hip": AMD GPUs. Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() var api_name = ctx.api() print("Using device API:", api_name) # Conditionally execute code based on the API if api_name == "cuda": print("Running on NVIDIA GPU") elif api_name == "hip": print("Running on AMD GPU") ``` **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string identifying the device API. 
### `enqueue_create_buffer` `enqueue_create_buffer[dtype: DType](self, size: Int) -> DeviceBuffer[dtype]` Enqueues a buffer creation using the `DeviceBuffer` constructor. For GPU devices, the space is allocated in the device's global memory. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type to be stored in the allocated memory. **Args:** * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements of `dtype` to allocate memory for. **Returns:** [`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer): The allocated buffer. **Raises:** If the operation fails. ### `create_buffer_sync` `create_buffer_sync[dtype: DType](self, size: Int) -> DeviceBuffer[dtype]` Creates a buffer synchronously using the `DeviceBuffer` constructor. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type to be stored in the allocated memory. **Args:** * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements of `dtype` to allocate memory for. **Returns:** [`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer): The allocated buffer. **Raises:** If the operation fails. ### `enqueue_create_host_buffer` `enqueue_create_host_buffer[dtype: DType](self, size: Int) -> HostBuffer[dtype]` Enqueues the creation of a HostBuffer. This function allocates memory on the host that is accessible by the device. The memory is page-locked (pinned) for efficient data transfer between host and device. Pinned memory is guaranteed to remain resident in the host's RAM and not to be paged or swapped out to disk. Memory allocated normally (for example, using [`alloc()`](/mojo/stdlib/memory/unsafe_pointer/alloc)) is pageable—individual pages of memory can be moved to secondary storage (disk/SSD) when main memory fills up. Using pinned memory allows devices to make fast transfers between host memory and device memory, because they can use direct memory access (DMA) to transfer data without relying on the CPU. 
Allocating too much pinned memory can cause performance issues, since it reduces the amount of memory available for other processes. Example: ```mojo from gpu.host import DeviceContext with DeviceContext() as ctx: # Allocate host memory accessible by the device var host_buffer = ctx.enqueue_create_host_buffer[DType.float32](1024) # Use the host buffer for device operations # ... ``` **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type to be stored in the allocated memory. **Args:** * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements of `dtype` to allocate memory for. **Returns:** [`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer): A `HostBuffer` object that wraps the allocated host memory. **Raises:** If memory allocation fails or if the device context is invalid. ### `compile_function` `compile_function[func_type: AnyTrivialRegType, //, func: func_type, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[DeviceContext.default_device_info.target()](), _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *, func_attribute: OptionalReg[FuncAttribute] = None, out result: DeviceFunction[func, None, target=DeviceContext.default_device_info.target(), compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose])` Compiles the provided function for execution on this device. **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the function. * ​func (`func_type`): The function to compile. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. 
* ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Change the compile options to different options than the ones associated with this `DeviceContext`. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An attribute to use when compiling the code (such as maximum shared memory size). **Returns:** [`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction): The compiled function. **Raises:** If the operation fails. ### `compile_function_unchecked` `compile_function_unchecked[func_type: AnyTrivialRegType, //, func: func_type, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[DeviceContext.default_device_info.target()](), _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *, func_attribute: OptionalReg[FuncAttribute] = None, out result: DeviceFunction[func, None, target=DeviceContext.default_device_info.target(), compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose])` Compiles the provided function for execution on this device. 
**Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the function. * ​func (`func_type`): The function to compile. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Change the compile options to different options than the ones associated with this `DeviceContext`. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An attribute to use when compiling the code (such as maximum shared memory size). **Returns:** [`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction): The compiled function via the `result` output parameter. **Raises:** If the operation fails. 
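The compile-once pattern with these `compile_function` variants can be sketched as follows (a minimal sketch, assuming the `DeviceFunction` overload of `enqueue_function` documented later in this section):

```mojo
from gpu.host import DeviceContext

fn kernel():
    print("hello from the GPU")

with DeviceContext() as ctx:
    # Compile once up front...
    var func = ctx.compile_function[kernel]()
    # ...then enqueue the compiled function any number of times.
    ctx.enqueue_function(func, grid_dim=1, block_dim=4)
    ctx.enqueue_function(func, grid_dim=1, block_dim=4)
    ctx.synchronize()
```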
### `compile_function_checked` `compile_function_checked[func_type: AnyTrivialRegType, declared_arg_types: Variadic[AnyType], //, func: func_type, signature_func: fn(*args: *declared_arg_types) -> None, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[DeviceContext.default_device_info.target()](), _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *, func_attribute: OptionalReg[FuncAttribute] = None, out result: DeviceFunction[func, declared_arg_types, compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose])` Compiles the provided function for execution on this device. **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the function. * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): Types of the arguments to pass to the device function. * ​func (`func_type`): The function to compile. * ​signature\_func (`fn(*args: *declared_arg_types) -> None`): The function to compile, passed in again. Used for checking argument dtypes later. Note: This will disappear in future versions. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Change the compile options to different options than the ones associated with this `DeviceContext`. 
* ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An attribute to use when compiling the code (such as maximum shared memory size). **Returns:** [`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction): The compiled function via the `result` output parameter. **Raises:** If the operation fails. `compile_function_checked[func_type: AnyTrivialRegType, declared_arg_types: Variadic[AnyType], //, func: func_type, signature_func: fn(*args: *declared_arg_types) capturing -> None, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[DeviceContext.default_device_info.target()](), _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *, func_attribute: OptionalReg[FuncAttribute] = None, out result: DeviceFunction[func, declared_arg_types, target=DeviceContext.default_device_info.target(), compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose])` Compiles the provided function for execution on this device. **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): Type of the function. * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): Types of the arguments to pass to the device function. * ​func (`func_type`): The function to compile. 
* ​signature\_func (`fn(*args: *declared_arg_types) capturing -> None`): The function to compile, passed in again. Used for checking argument dtypes later. Note: This will disappear in future versions. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Change the compile options to different options than the ones associated with this `DeviceContext`. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An attribute to use when compiling the code (such as maximum shared memory size). **Returns:** [`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction): The compiled function via the `result` output parameter. **Raises:** If the operation fails. 
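Putting the checked variant together with a kernel that takes arguments, a minimal sketch (the argument-passing form of `enqueue_function_checked` is assumed from the surrounding examples; the kernel and buffer here are illustrative):

```mojo
from gpu import thread_idx
from gpu.host import DeviceContext
from memory import UnsafePointer

fn fill(data: UnsafePointer[Float32], value: Float32):
    # One thread writes one element.
    data[thread_idx.x] = value

with DeviceContext() as ctx:
    var buf = ctx.enqueue_create_buffer[DType.float32](32)
    # `fill` is passed twice: once as the function to compile and once
    # as the signature used to check argument types at the call site.
    var f = ctx.compile_function_checked[fill, fill]()
    ctx.enqueue_function_checked(
        f, buf.unsafe_ptr(), Float32(1.0), grid_dim=1, block_dim=32
    )
    ctx.synchronize()
```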
### `compile_function_experimental` `compile_function_experimental[declared_arg_types: Variadic[AnyType], //, func: fn(*args: *declared_arg_types) -> None, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[DeviceContext.default_device_info.target()](), _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *, func_attribute: OptionalReg[FuncAttribute] = None, out result: DeviceFunction[func, declared_arg_types, target=DeviceContext.default_device_info.target(), compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose])` Compiles the provided function for execution on this device. **Parameters:** * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): Types of the arguments to pass to the device function. * ​func (`fn(*args: *declared_arg_types) -> None`): The function to compile. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Change the compile options to different options than the ones associated with this `DeviceContext`. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. 
* ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An attribute to use when compiling the code (such as maximum shared memory size). **Returns:** [`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction): The compiled function via the `result` output parameter. **Raises:** If the operation fails. `compile_function_experimental[declared_arg_types: Variadic[AnyType], //, func: fn(*args: *declared_arg_types) capturing -> None, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[DeviceContext.default_device_info.target()](), _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *, func_attribute: OptionalReg[FuncAttribute] = None, out result: DeviceFunction[func, declared_arg_types, target=DeviceContext.default_device_info.target(), compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose])` Compiles the provided function for execution on this device. **Parameters:** * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): Types of the arguments to pass to the device function. * ​func (`fn(*args: *declared_arg_types) capturing -> None`): The function to compile. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. 
* ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Change the compile options to different options than the ones associated with this `DeviceContext`. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): An attribute to use when compiling the code (such as maximum shared memory size). **Returns:** [`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction): The compiled function via the `result` output parameter. **Raises:** If the operation fails. ### `load_function` `load_function[func_type: AnyTrivialRegType, //, func: func_type](self, *, function_name: StringSlice[origin], asm: StringSlice[origin], func_attribute: OptionalReg[FuncAttribute] = None, out result: DeviceExternalFunction)` Loads a pre-compiled device function from assembly code. This method loads an external GPU function from provided assembly code (PTX/SASS) rather than compiling it from Mojo source. This is useful for integrating with existing CUDA/HIP code or for using specialized assembly optimizations. Example: ```mojo from gpu.host import DeviceContext from gpu.host.device_context import DeviceExternalFunction fn func_signature( # Arguments being passed to the assembly code # e.g. two pointers and a length input: UnsafePointer[Float32], output: UnsafePointer[Float32], len: Int, ): # No body because that is passed as assembly code below. pass var ctx = DeviceContext() var ptx_code = "..." 
# PTX assembly code var ext_func = ctx.load_function[func_signature]( function_name="my_kernel", asm=ptx_code, ) ``` **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The dtype of the function to load. * ​func (`func_type`): The function reference. **Args:** * ​function\_name ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The name of the function in the assembly code. * ​asm ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The assembly code (PTX/SASS) containing the function. * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Optional attribute to apply to the function (such as maximum shared memory size). **Returns:** [`DeviceExternalFunction`](/mojo/stdlib/gpu/host/device_context/DeviceExternalFunction): The loaded function is stored in the `result` parameter. **Raises:** If loading the function fails or the assembly code is invalid. ### `enqueue_function` `enqueue_function[func_type: AnyTrivialRegType, //, func: func_type, *Ts: AnyType, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[DeviceContext.default_device_info.target()](), _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), func_attribute: OptionalReg[FuncAttribute] = None, location: OptionalReg[_SourceLocation] = None)` Compiles and enqueues a kernel for execution on this device. 
You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` Compiling and enqueuing in a single call incurs 50-500 nanoseconds of overhead per enqueue, so if you are reusing the same function and parameters multiple times, compile it once up front to remove that overhead: ```mojo with DeviceContext() as ctx: var compile_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of the function to launch. * ​func (`func_type`): The function to launch. * ​\*Ts ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The types of the arguments being passed to the function. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Change the compile options to different options than the ones associated with this `DeviceContext`. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). 
**Args:** * ​\*args (`*Ts`): Variadic arguments which are passed to the `func`. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The grid dimensions. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The block dimensions. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The cluster dimensions. * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory allocated per thread block, shared by the threads within that block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): A `List` of launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): A `List` of constant memory mappings. * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): `CUfunction_attribute` enum. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. `enqueue_function[*Ts: AnyType](self, f: DeviceFunction[func, declared_arg_types, target=target, compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose], *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), location: OptionalReg[_SourceLocation] = None)` Enqueues a compiled function for execution on this device. 
You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` Compiling and enqueuing in a single call incurs 50-500 nanoseconds of overhead per enqueue, so if you are reusing the same function and parameters multiple times, compile the function once up front to remove that overhead: ```mojo from gpu.host import DeviceContext with DeviceContext() as ctx: var compiled_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​\*Ts ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Argument types. **Args:** * ​f ([`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction)): The compiled function to execute. * ​\*args (`*Ts`): Arguments to pass to the function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of the compute grid, made up of thread blocks. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of each thread block in the grid. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Dimensions of clusters (if the thread blocks are grouped into clusters). * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory per thread block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): Launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): Constant memory mapping. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. 
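The `grid_dim` and `block_dim` arguments accept either a single integer or a tuple of up to three dimensions, as in the struct-level example earlier; a short sketch:

```mojo
from gpu import block_idx, thread_idx
from gpu.host import DeviceContext

fn kernel():
    print("block:", block_idx.x, "thread:", thread_idx.x, thread_idx.y)

with DeviceContext() as ctx:
    # 2 blocks, each with a 2x2 arrangement of threads.
    ctx.enqueue_function[kernel](grid_dim=2, block_dim=(2, 2))
    ctx.synchronize()
```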
### `enqueue_function_unchecked` `enqueue_function_unchecked[func_type: AnyTrivialRegType, //, func: func_type, *Ts: AnyType, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), func_attribute: OptionalReg[FuncAttribute] = None, location: OptionalReg[_SourceLocation] = None)` Compiles and enqueues a kernel for execution on this device. You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` Compiling and enqueuing in a single call incurs 50-500 nanoseconds of overhead per enqueue, so if you are reusing the same function and parameters multiple times, compile it once up front to remove that overhead: ```mojo with DeviceContext() as ctx: var compile_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of the function to launch. * ​func (`func_type`): The function to launch. * ​\*Ts ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): The types of the arguments being passed to the function. 
* ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​\*args (`*Ts`): Variadic arguments which are passed to the `func`. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The grid dimensions. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The block dimensions. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The cluster dimensions. * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Per-block memory shared between blocks. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): A `List` of launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): A `List` of constant memory mappings. * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): `CUfunction_attribute` enum. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. 
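As an illustrative sketch (not from the original reference), the dump options can be set as keyword parameters at the launch site. This assumes the `StaticString` alternative of the `dump_asm` variant accepts a relative file path:

```mojo
from gpu.host import DeviceContext

fn kernel():
    print("hello from the GPU")

def main():
    with DeviceContext() as ctx:
        # Launch the kernel and dump its compiled assembly to "kernel.asm".
        # The hypothetical path "kernel.asm" is relative to the working directory.
        ctx.enqueue_function_unchecked[kernel, dump_asm = StaticString("kernel.asm")](
            grid_dim=1, block_dim=1
        )
        ctx.synchronize()
```

Passing `True` instead of a path would dump to stdout, and a `fn() capturing -> Path` can compute the destination lazily.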
`enqueue_function_unchecked[*Ts: AnyType](self, f: DeviceFunction[func, declared_arg_types, target=target, compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose], *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), location: OptionalReg[_SourceLocation] = None)` Enqueues a compiled function for execution on this device. You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` If you are reusing the same function and parameters multiple times, this incurs 50-500 nanoseconds of overhead per enqueue, so you can compile the function first to remove the overhead: ```mojo from gpu.host import DeviceContext with DeviceContext() as ctx: var compiled_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​\*Ts ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Argument types. **Args:** * ​f ([`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction)): The compiled function to execute. * ​\*args (`*Ts`): Arguments to pass to the function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of the compute grid, made up of thread blocks. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of each thread block in the grid. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Dimensions of clusters (if the thread blocks are grouped into clusters). 
* ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory per thread block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): Launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): Constant memory mapping. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. ### `enqueue_function_checked` `enqueue_function_checked[*Ts: DevicePassable](self, f: DeviceFunction[func, declared_arg_types, target=target, compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose], *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), location: OptionalReg[_SourceLocation] = None)` Enqueues a pre-compiled checked function for execution on this device. This overload requires a `DeviceFunction` that was compiled with type checking enabled (via `compile_function_checked`). The function will verify that the argument types match the declared types at compile time. ```mojo from gpu.host import DeviceContext fn kernel(x: Int): print("Value:", x) with DeviceContext() as ctx: var compiled_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compiled_func, 42, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​\*Ts ([`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable)): Argument types. **Args:** * ​f ([`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction)): The compiled function to execute. * ​\*args (`*Ts`): Arguments to pass to the function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of the compute grid, made up of thread blocks. 
* ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of each thread block in the grid. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Dimensions of clusters (if the thread blocks are grouped into clusters). * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory per thread block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): Launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): Constant memory mapping. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. `enqueue_function_checked[*Ts: AnyType](self, f: DeviceExternalFunction, *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), location: OptionalReg[_SourceLocation] = None)` Enqueues an external device function for execution on this device. This overload accepts a `DeviceExternalFunction` that was loaded from assembly code (PTX/SASS). External functions are pre-compiled GPU kernels that can be integrated with Mojo code. Example: ```mojo from gpu.host import DeviceContext fn vec_add_sig( in0: UnsafePointer[Float32], in1: UnsafePointer[Float32], out: UnsafePointer[Float32], len: Int, ): pass with DeviceContext() as ctx: var func = ctx.load_function[vec_add_sig]( function_name="vectorAdd", asm=ptx_code, ) ctx.enqueue_function_checked( func, in0_buf, in1_buf, out_buf, 1024, grid_dim=Dim(32), block_dim=Dim(32), ) ctx.synchronize() ``` **Parameters:** * ​\*Ts ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Argument types to pass to the external function. 
**Args:** * ​f ([`DeviceExternalFunction`](/mojo/stdlib/gpu/host/device_context/DeviceExternalFunction)): The external device function to execute. * ​\*args (`*Ts`): Arguments to pass to the function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of the compute grid, made up of thread blocks. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of each thread block in the grid. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Dimensions of clusters (if the thread blocks are grouped into clusters). * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory per thread block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): Launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): Constant memory mapping. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. 
`enqueue_function_checked[func_type: AnyTrivialRegType, declared_arg_types: Variadic[AnyType], //, func: func_type, signature_func: fn(*args: *declared_arg_types) -> None, *actual_arg_types: DevicePassable, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *actual_arg_types, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), func_attribute: OptionalReg[FuncAttribute] = None, location: OptionalReg[_SourceLocation] = None)` Compiles and enqueues a kernel for execution on this device with type checking. This function performs compile-time type checking on the kernel arguments, ensuring that the types passed match the declared signature. Both `func` and `signature_func` should typically be the same kernel function (this redundancy is required for type checking and will be removed in future versions). Most parameters are inferred automatically. In typical usage, you only need to pass the kernel function twice (as both `func` and `signature_func`): ```mojo from gpu.host import DeviceContext from layout import Layout, LayoutTensor fn vector_add( a: LayoutTensor[DType.float32, Layout.row_major(1000), MutAnyOrigin], b: LayoutTensor[DType.float32, Layout.row_major(1000), MutAnyOrigin], c: LayoutTensor[DType.float32, Layout.row_major(1000), MutAnyOrigin], ): # ... kernel implementation ... pass with DeviceContext() as ctx: # Create tensors a, b, c... 
# Most parameters are inferred automatically: ctx.enqueue_function_checked[vector_add, vector_add]( a, b, c, grid_dim=4, block_dim=256 ) ctx.synchronize() ``` **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of the function to launch (usually inferred). * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): The declared argument types from the function signature (usually inferred). * ​func (`func_type`): The kernel function to compile and launch. * ​signature\_func (`fn(*args: *declared_arg_types) -> None`): The kernel function, passed again for type checking. Typically the same as `func`. * ​\*actual\_arg\_types ([`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable)): The types of the arguments being passed (usually inferred). * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​\*args (`*actual_arg_types`): Variadic arguments which are passed to the kernel function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The grid dimensions. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The block dimensions. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The cluster dimensions. 
* ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Per-block memory shared between blocks. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): A `List` of launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): A `List` of constant memory mappings. * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): `CUfunction_attribute` enum. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. `enqueue_function_checked[func_type: AnyTrivialRegType, declared_arg_types: Variadic[AnyType], //, func: func_type, signature_func: fn(*args: *declared_arg_types) capturing -> None, *actual_arg_types: DevicePassable, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *actual_arg_types, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), func_attribute: OptionalReg[FuncAttribute] = None, location: OptionalReg[_SourceLocation] = None)` Compiles and enqueues a capturing kernel for execution on this device with type checking. This overload is for kernels that capture variables from their enclosing scope. The `capturing` annotation on the signature function indicates that the kernel can access variables from the surrounding context. Like the non-capturing overload, both `func` and `signature_func` should typically be the same kernel function. Most parameters are inferred automatically. 
This overload is selected when your kernel captures variables from its surrounding scope: ```mojo from gpu.host import DeviceContext from layout import Layout, LayoutTensor fn main(): with DeviceContext() as ctx: var scale_factor = 2.0 # This kernel captures 'scale_factor' from the enclosing scope fn scale_kernel(data: LayoutTensor[DType.float32, Layout.row_major(100), MutAnyOrigin]): # Uses captured scale_factor variable pass # Create tensor 'data'... # Most parameters are inferred: ctx.enqueue_function_checked[scale_kernel, scale_kernel]( data, grid_dim=1, block_dim=256 ) ctx.synchronize() ``` **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of the function to launch (usually inferred). * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): The declared argument types from the function signature (usually inferred). * ​func (`func_type`): The capturing kernel function to compile and launch. * ​signature\_func (`fn(*args: *declared_arg_types) capturing -> None`): The kernel function, passed again for type checking. Typically the same as `func`. * ​\*actual\_arg\_types ([`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable)): The types of the arguments being passed (usually inferred). * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. 
* ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​\*args (`*actual_arg_types`): Variadic arguments which are passed to the kernel function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The grid dimensions. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The block dimensions. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The cluster dimensions. * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Per-block memory shared between blocks. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): A `List` of launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): A `List` of constant memory mappings. * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): `CUfunction_attribute` enum. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. 
### `enqueue_function_experimental` `enqueue_function_experimental[declared_arg_types: Variadic[AnyType], //, func: fn(*args: *declared_arg_types) -> None, *actual_arg_types: DevicePassable, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *actual_arg_types, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), func_attribute: OptionalReg[FuncAttribute] = None, location: OptionalReg[_SourceLocation] = None)` Compiles and enqueues a kernel for execution on this device. You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` If you are reusing the same function and parameters multiple times, this incurs 50-500 nanoseconds of overhead per enqueue, so you can compile it first to remove the overhead: ```mojo with DeviceContext() as ctx: var compile_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): Types of the arguments to pass to the device function. * ​func (`fn(*args: *declared_arg_types) -> None`): The function to compile and launch. 
* ​\*actual\_arg\_types ([`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable)): The types of the arguments being passed to the function. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​\*args (`*actual_arg_types`): Variadic arguments which are passed to the `func`. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The grid dimensions. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The block dimensions. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The cluster dimensions. * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Per-block memory shared between blocks. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): A `List` of launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): A `List` of constant memory mappings. * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): `CUfunction_attribute` enum. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. 
`enqueue_function_experimental[declared_arg_types: Variadic[AnyType], //, func: fn(*args: *declared_arg_types) capturing -> None, *actual_arg_types: DevicePassable, *, dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _ptxas_info_verbose: Bool = False](self, *args: *actual_arg_types, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), func_attribute: OptionalReg[FuncAttribute] = None, location: OptionalReg[_SourceLocation] = None)` Compiles and enqueues a kernel for execution on this device. This overload takes in a function that's `capturing`. You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` If you are reusing the same function and parameters multiple times, this incurs 50-500 nanoseconds of overhead per enqueue, so you can compile it first to remove the overhead: ```mojo with DeviceContext() as ctx: var compile_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compile_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​declared\_arg\_types ([`Variadic`](/mojo/stdlib/builtin/variadics/Variadic)): Types of the arguments to pass to the device function. * ​func (`fn(*args: *declared_arg_types) capturing -> None`): The function to compile and launch. 
* ​\*actual\_arg\_types ([`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable)): The types of the arguments being passed to the function. * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the compiled assembly, pass `True`, or a file path to dump to, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): To dump the generated LLVM code, pass `True`, or a file path to dump to, or a function returning a file path. * ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Pass `True`, or a file path to dump to, or a function returning a file path. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Only runs on NVIDIA targets, and requires CUDA Toolkit to be installed. Changes `dump_asm` to output verbose PTX assembly (default `False`). **Args:** * ​\*args (`*actual_arg_types`): Variadic arguments which are passed to the `func`. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The grid dimensions. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): The block dimensions. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): The cluster dimensions. * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Per-block memory shared between blocks. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): A `List` of launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): A `List` of constant memory mappings. * ​func\_attribute ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): `CUfunction_attribute` enum. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. 
`enqueue_function_experimental[func_type: AnyTrivialRegType, //, func: func_type, declared_arg_types: Optional[Variadic[AnyType]], *Ts: DevicePassable](self, f: DeviceFunction[func, declared_arg_types, target=target, compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose], *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()), location: OptionalReg[_SourceLocation] = None)` Enqueues a compiled function for execution on this device. You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` If you are reusing the same function and parameters multiple times, this incurs 50-500 nanoseconds of overhead per enqueue, so you can compile the function first to remove the overhead: ```mojo from gpu.host import DeviceContext with DeviceContext() as ctx: var compiled_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of the function to launch (usually inferred). * ​func (`func_type`): The function that was compiled into `f` (usually inferred). * ​declared\_arg\_types ([`Optional`](/mojo/stdlib/collections/optional/Optional)): The declared argument types from the function signature, if available (usually inferred). * ​\*Ts ([`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable)): Argument types. **Args:** * ​f ([`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction)): The compiled function to execute. * ​\*args (`*Ts`): Arguments to pass to the function. 
* ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of the compute grid, made up of thread blocks. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of each thread block in the grid. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Dimensions of clusters (if the thread blocks are grouped into clusters). * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory per thread block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): Launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): Constant memory mapping. * ​location ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Source location for the function call. **Raises:** If the operation fails. ### `execution_time` `execution_time[func: fn(DeviceContext) raises capturing -> None](self, num_iters: Int) -> Int` Measures the execution time of a function that takes a DeviceContext parameter. This method times the execution of a provided function that requires the DeviceContext as a parameter. It runs the function for the specified number of iterations and returns the total elapsed time in nanoseconds. Example: ```mojo from gpu.host import DeviceContext fn gpu_operation(ctx: DeviceContext) raises capturing [_] -> None: # Perform some GPU operation using ctx pass with DeviceContext() as ctx: # Measure execution time of a function that uses the context var time_ns = ctx.execution_time[gpu_operation](10) print("Execution time for 10 iterations:", time_ns, "ns") ``` **Parameters:** * ​func (`fn(DeviceContext) raises capturing -> None`): A function that takes a DeviceContext parameter to execute and time. **Args:** * ​num\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of iterations to run the function. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total elapsed time in nanoseconds for all iterations. 
**Raises:** If the timer operations fail or if the function raises an exception. `execution_time[func: fn() raises capturing -> None](self, num_iters: Int) -> Int` Measures the execution time of a function over multiple iterations. This method times the execution of a provided function that doesn't require the DeviceContext as a parameter. It runs the function for the specified number of iterations and returns the total elapsed time in nanoseconds. Example: ```mojo from gpu.host import DeviceContext fn some_gpu_operation() raises capturing [_] -> None: # Perform some GPU operation pass with DeviceContext() as ctx: # Measure execution time of a function var time_ns = ctx.execution_time[some_gpu_operation](10) print("Execution time for 10 iterations:", time_ns, "ns") ``` **Parameters:** * ​func (`fn() raises capturing -> None`): A function with no parameters to execute and time. **Args:** * ​num\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of iterations to run the function. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total elapsed time in nanoseconds for all iterations. **Raises:** If the timer operations fail or if the function raises an exception. ### `push_context` `push_context(self) -> _DeviceContextScope` Returns a context manager that ensures this device's driver context is active. This method returns a context manager that pushes this device's driver context as the current context on entry and restores the previous context on exit. This is useful for operations that require a specific GPU context to be active, such as cuDNN operations on multi-GPU systems. Example: ```mojo var ctx = DeviceContext(device_id=1) # Ensure GPU 1's context is active for these operations. with ctx.push_context(): # All GPU operations here will use GPU 1's context. ... # call external stateful APIs, such as cudnn. # Previous context is automatically restored ``` **Returns:** `_DeviceContextScope`: A context manager that manages the driver context stack. 
**Raises:** If there's an error switching contexts. ### `set_as_current` `set_as_current(self)` For use with libraries that require a specific GPU context to be active. Sets the current device to the one associated with this DeviceContext. Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext(device_id=1) ctx.set_as_current() ``` **Raises:** If there's an error setting the current device. ### `execution_time_iter` `execution_time_iter[func: fn(DeviceContext, Int) raises capturing -> None](self, num_iters: Int) -> Int` Measures the execution time of a function that takes iteration index as input. This method times the execution of a provided function that requires both the DeviceContext and the current iteration index as parameters. It runs the function for the specified number of iterations, passing the iteration index to each call, and returns the total elapsed time in nanoseconds. Example: ```mojo from gpu.host import DeviceContext var my_kernel = DeviceFunction(...) fn benchmark_kernel(ctx: DeviceContext, i: Int) raises capturing [_] -> None: # Run kernel with different parameters based on iteration ctx.enqueue_function[my_kernel](grid_dim=Dim(i), block_dim=Dim(256)) with DeviceContext() as ctx: # Measure execution time with iteration awareness var time_ns = ctx.execution_time_iter[benchmark_kernel](10) print("Total execution time:", time_ns, "ns") ``` **Parameters:** * ​func (`fn(DeviceContext, Int) raises capturing -> None`): A function that takes the DeviceContext and an iteration index. **Args:** * ​num\_iters ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of iterations to run the function. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The total elapsed time in nanoseconds for all iterations. **Raises:** If the timer operations fail or if the function raises an exception. 
### `enqueue_copy` `enqueue_copy[dtype: DType](self, dst_buf: DeviceBuffer[dtype], src_ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space])` Enqueues an async copy from the host to the provided device buffer. The number of bytes copied is determined by the size of the device buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_buf ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Device buffer to copy to. * ​src\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Host pointer to copy from. **Raises:** If the operation fails. `enqueue_copy[dtype: DType](self, dst_buf: HostBuffer[dtype], src_ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space])` Enqueues an async copy from the host to the provided host buffer. The number of bytes copied is determined by the size of the host buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_buf ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Host buffer to copy to. * ​src\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Host pointer to copy from. **Raises:** If the operation fails. `enqueue_copy[dtype: DType](self, dst_ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space], src_buf: DeviceBuffer[dtype])` Enqueues an async copy from the device to the host. The number of bytes copied is determined by the size of the device buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Host pointer to copy to. * ​src\_buf ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Device buffer to copy from. **Raises:** If the operation fails. 
`enqueue_copy[dtype: DType](self, dst_ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space], src_buf: HostBuffer[dtype])` Enqueues an async copy from a host buffer to the host. The number of bytes copied is determined by the size of the host buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Host pointer to copy to. * ​src\_buf ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Host buffer to copy from. **Raises:** If the operation fails. `enqueue_copy[dtype: DType](self, dst_ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space], src_ptr: UnsafePointer[Scalar[dtype], origin, address_space=address_space], size: Int)` Enqueues an async copy of `size` elements from a device pointer to another device pointer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Device pointer to copy to. * ​src\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Device pointer to copy from. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of elements (of the specified `DType`) to copy. **Raises:** If the operation fails. `enqueue_copy[dtype: DType](self, dst_buf: DeviceBuffer[dtype], src_buf: DeviceBuffer[dtype])` Enqueues an async copy from one device buffer to another. The amount of data transferred is determined by the size of the destination buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_buf ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Device buffer to copy to. * ​src\_buf ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Device buffer to copy from. Must be at least as large as `dst_buf`. 
**Raises:** If the operation fails. `enqueue_copy[dtype: DType](self, dst_buf: DeviceBuffer[dtype], src_buf: HostBuffer[dtype])` Enqueues an async copy from a host buffer to a device buffer. The amount of data transferred is determined by the size of the destination buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_buf ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Device buffer to copy to. * ​src\_buf ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Host buffer to copy from. Must be at least as large as `dst_buf`. **Raises:** If the operation fails. `enqueue_copy[dtype: DType](self, dst_buf: HostBuffer[dtype], src_buf: DeviceBuffer[dtype])` Enqueues an async copy from a device buffer to a host buffer. The amount of data transferred is determined by the size of the destination buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_buf ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Host buffer to copy to. * ​src\_buf ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Device buffer to copy from. Must be at least as large as `dst_buf`. **Raises:** If the operation fails. `enqueue_copy[dtype: DType](self, dst_buf: HostBuffer[dtype], src_buf: HostBuffer[dtype])` Enqueues an async copy from one host buffer to another. The amount of data transferred is determined by the size of the destination buffer. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data being copied. **Args:** * ​dst\_buf ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Host buffer to copy to. * ​src\_buf ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Host buffer to copy from. Must be at least as large as `dst_buf`. **Raises:** If the operation fails. 
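Taken together, the `enqueue_copy` overloads cover transfers in every direction between host and device memory. The following is a minimal round-trip sketch, not a definitive recipe; it assumes the `enqueue_create_buffer` and `enqueue_create_host_buffer` allocation methods documented elsewhere in this reference:

```mojo
from gpu.host import DeviceContext

with DeviceContext() as ctx:
    # Allocate a host buffer and a device buffer of 1024 float32 elements.
    var host_buf = ctx.enqueue_create_host_buffer[DType.float32](1024)
    var dev_buf = ctx.enqueue_create_buffer[DType.float32](1024)

    # Zero the device buffer, then copy host -> device and device -> host.
    # For the buffer overloads, the copy size comes from the destination.
    ctx.enqueue_memset(dev_buf, 0)
    ctx.enqueue_copy(dev_buf, host_buf)
    ctx.enqueue_copy(host_buf, dev_buf)

    # All of the above calls are asynchronous; block until they finish.
    ctx.synchronize()
```

Because the copies are enqueued on the context's stream, they execute in order relative to each other, so the final host buffer reflects the memset once `synchronize()` returns.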
### `enqueue_memset` `enqueue_memset[dtype: DType](self, dst: DeviceBuffer[dtype], val: Scalar[dtype])` Enqueues an async memset operation, setting all of the elements in the destination device buffer to the specified value. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data stored in the buffer. **Args:** * ​dst ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): Destination buffer. * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Value to set all elements of `dst` to. **Raises:** If the operation fails. `enqueue_memset[dtype: DType](self, dst: HostBuffer[dtype], val: Scalar[dtype])` Enqueues an async memset operation, setting all of the elements in the destination host buffer to the specified value. **Parameters:** * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Type of the data stored in the buffer. **Args:** * ​dst ([`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer)): Destination buffer. * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): Value to set all elements of `dst` to. **Raises:** If the operation fails. ### `create_event` `create_event[*, blocking_sync: Bool = False, disable_timing: Bool = True, interprocess: Bool = False](self) -> DeviceEvent` Creates a new event for synchronization between streams. Provides the best performance by default, disabling timing and blocking sync. `DeviceContext.execution_time()` provides the functionality required for timing kernels by passing it a closure, and is functionally equivalent to recording start and end events, then calculating the elapsed time. 
Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() var default_stream = ctx.stream() var new_stream = ctx.create_stream() # Create an event var event = ctx.create_event() # Wait for the event in new_stream new_stream.enqueue_wait_for(event) # new_stream can continue default_stream.record_event(event) default_stream.synchronize() ``` **Parameters:** * ​blocking\_sync ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Enable `event.synchronize()` to block until the event has been recorded. Incurs overhead compared to `stream.enqueue_wait_for(event)` (default: False). * ​disable\_timing ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Remove timing overhead (default: True). * ​interprocess ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Enable interprocess synchronization, currently unimplemented. (default: False). **Returns:** [`DeviceEvent`](/mojo/stdlib/gpu/host/device_context/DeviceEvent): A DeviceEvent that can be used for synchronization. **Raises:** If event creation fails. ### `stream_priority_range` `stream_priority_range(self) -> StreamPriorityRange` Returns the range of stream priorities supported by this device context. **Returns:** `StreamPriorityRange`: A StreamPriorityRange object containing the minimum and maximum stream priorities. **Raises:** If the operation fails. ### `create_stream` `create_stream(self, *, blocking: Bool = True) -> DeviceStream` Creates a new stream associated with the given device context. **Args:** * ​blocking ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the stream should be blocking. **Returns:** `DeviceStream`: The newly created device stream. **Raises:** If stream creation fails. `create_stream(self, *, priority: Int, blocking: Bool = True) -> DeviceStream` Creates a new stream associated with the given device context. 
To create a non-blocking stream with the highest priority, use: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() var priority = ctx.stream_priority_range().largest var stream = ctx.create_stream(priority=priority, blocking=False) ``` **Args:** * ​priority ([`Int`](/mojo/stdlib/builtin/int/Int)): The priority of the stream. * ​blocking ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the stream should be blocking. **Returns:** `DeviceStream`: The newly created device stream with the specified priority. **Raises:** If stream creation fails. ### `synchronize` `synchronize(self)` Blocks until all asynchronous calls on the stream associated with this device context have completed. **Raises:** If the operation fails. This should never be necessary when writing a custom operation. ### `enqueue_wait_for` `enqueue_wait_for(self, other: Self)` Enqueues a wait operation for another device context to complete its work. This method creates a dependency between two device contexts, ensuring that operations in the current context will not begin execution until all previously enqueued operations in the other context have completed. This is useful for synchronizing work across multiple devices or streams. Example: ```mojo from gpu.host import DeviceContext # Create two device contexts var ctx1 = DeviceContext(0) # First GPU var ctx2 = DeviceContext(1) # Second GPU # Enqueue operations on ctx1 # ... # Make ctx2 wait for ctx1 to complete before proceeding ctx2.enqueue_wait_for(ctx1) # Enqueue operations on ctx2 that depend on ctx1's completion # ... ``` **Args:** * ​other (`Self`): The device context whose operations must complete before operations in this context can proceed. **Raises:** If there's an error enqueuing the wait operation or if the operation is not supported by the underlying device API. ### `get_api_version` `get_api_version(self) -> Int` Returns the API version associated with this device. 
This method retrieves the version number of the GPU driver currently installed on the system for the device associated with this context. The version is returned as an integer that can be used to check compatibility with specific features or to troubleshoot driver-related issues. Example: ```mojo from gpu.host import DeviceContext with DeviceContext() as ctx: # Get the API version var api_version = ctx.get_api_version() print("GPU API version:", api_version) ``` **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): An integer representing the driver version. **Raises:** If the driver version cannot be retrieved or if the device context is invalid. ### `get_attribute` `get_attribute(self, attr: DeviceAttribute) -> Int` Returns the specified attribute for this device. Use the aliases defined by [DeviceAttribute](/mojo/stdlib/gpu/host/device_attribute/DeviceAttribute) to specify attributes. For example: ```mojo from gpu.host import DeviceAttribute, DeviceContext def main(): var ctx = DeviceContext() var attr = DeviceAttribute.MAX_BLOCKS_PER_MULTIPROCESSOR var max_blocks = ctx.get_attribute(attr) print(max_blocks) ``` **Args:** * ​attr ([`DeviceAttribute`](/mojo/stdlib/gpu/host/device_attribute/DeviceAttribute)): The device attribute to query. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value for `attr` on this device. **Raises:** If the operation fails. ### `is_compatible` `is_compatible(self) -> Bool` Returns True if this device is compatible with MAX. This method checks whether the current device is compatible with the Modular Accelerated Execution (MAX) runtime. It's useful for validating that the device can execute the compiled code before attempting operations. Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() print("Device is compatible with MAX:", ctx.is_compatible()) ``` **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the device is compatible with MAX, False otherwise. 
### `id` `id(self) -> Int64` Returns the ID associated with this device. This method retrieves the unique identifier for the current device. Device IDs are used to distinguish between multiple devices in a system and are often needed for multi-GPU programming. Example: ```mojo var ctx = DeviceContext() try: var device_id = ctx.id() print("Using device with ID:", device_id) except: print("Failed to get device ID") ``` **Returns:** [`Int64`](/mojo/stdlib/builtin/simd/#int64): The unique device ID as an Int64. **Raises:** If there's an error retrieving the device ID. ### `get_memory_info` `get_memory_info(self) -> Tuple[UInt, UInt]` Returns the free and total memory size for this device. This method queries the current state of device memory, providing information about how much memory is available and the total memory capacity of the device. This is useful for memory management and determining if there's enough space for planned operations. Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() try: (free, total) = ctx.get_memory_info() print("Free memory:", free / (1024*1024), "MB") print("Total memory:", total / (1024*1024), "MB") except: print("Failed to get memory information") ``` **Returns:** [`Tuple`](/mojo/stdlib/builtin/tuple/Tuple): A tuple of (free memory, total memory) in bytes. **Raises:** If there's an error retrieving the memory information. ### `can_access` `can_access(self, peer: Self) -> Bool` Returns True if this device can access the identified peer device. This method checks whether the current device can directly access memory on the specified peer device. Peer-to-peer access allows for direct memory transfers between devices without going through host memory, which can significantly improve performance in multi-GPU scenarios. 
Example: ```mojo from gpu.host import DeviceContext var ctx1 = DeviceContext(0) # First GPU var ctx2 = DeviceContext(1) # Second GPU try: if ctx1.can_access(ctx2): print("Direct peer access is possible") ctx1.enable_peer_access(ctx2) else: print("Direct peer access is not supported") except: print("Failed to check peer access capability") ``` **Args:** * ​peer (`Self`): The peer device to check for accessibility. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the current device can access the peer device, False otherwise. **Raises:** If there's an error checking peer access capability. ### `enable_peer_access` `enable_peer_access(self, peer: Self)` Enables direct memory access to the peer device. This method establishes peer-to-peer access from the current device to the specified peer device. Once enabled, the current device can directly read from and write to memory allocated on the peer device without going through host memory, which can significantly improve performance for multi-GPU operations. Notes: * It's recommended to call `can_access()` first to check if peer access is possible. * Peer access is not always symmetric; you may need to enable access in both directions. Example: ```mojo from gpu.host import DeviceContext var ctx1 = DeviceContext(0) # First GPU var ctx2 = DeviceContext(1) # Second GPU try: if ctx1.can_access(ctx2): ctx1.enable_peer_access(ctx2) print("Peer access enabled from device 0 to device 1") # For bidirectional access if ctx2.can_access(ctx1): ctx2.enable_peer_access(ctx1) print("Peer access enabled from device 1 to device 0") else: print("Peer access not supported between these devices") except: print("Failed to enable peer access") ``` **Args:** * ​peer (`Self`): The peer device to enable access to. **Raises:** If there's an error enabling peer access or if peer access is not supported between the devices. 
### `supports_multicast` `supports_multicast(self) -> Bool` Returns True if this device supports multicast memory mappings. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the current device supports multicast memory, False otherwise. **Raises:** If there's an error checking multicast support. ### `number_of_devices` `static number_of_devices(*, api: String = DeviceContext.default_device_info.api) -> Int` Returns the number of devices available that support the specified API. This function queries the system for available devices that support the requested API (such as CUDA or HIP). It's useful for determining how many accelerators are available before allocating resources or distributing work. Example: ```mojo from gpu.host import DeviceContext # Get number of CUDA devices var num_cuda_devices = DeviceContext.number_of_devices(api="cuda") # Get number of devices for the default API var num_devices = DeviceContext.number_of_devices() ``` **Args:** * ​api ([`String`](/mojo/stdlib/collections/string/string/String)): Requested device API (for example, "cuda" or "hip"). Defaults to the device API specified by the current target accelerator. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of available devices supporting the specified API. ### `enable_all_peer_access` `static enable_all_peer_access()` Enable peer-to-peer memory access between all available accelerators. This function detects all available accelerators in the system and enables peer-to-peer (P2P) memory access between every pair of devices. When peer access is enabled, kernels running on one device can directly access memory allocated on another device without going through host memory. This is crucial for efficient multi-GPU operations like allreduce. 
The function is a no-op when: * No accelerators are available * Only one accelerator is available * Peer access is already enabled between devices Example: ```mojo from gpu.host import DeviceContext # Enable P2P access between all GPUs DeviceContext.enable_all_peer_access() # Now GPUs can directly access each other's memory ``` **Raises:** If peer access cannot be enabled between any pair of devices. This can happen if the hardware doesn't support P2P access or if there's a configuration issue.
--- ## DeviceEvent
`struct DeviceEvent` Represents a GPU event for synchronization between streams. A DeviceEvent allows for fine-grained synchronization between different GPU streams. Events can be recorded in one stream and waited for in another, enabling efficient coordination of asynchronous GPU operations. Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() var default_stream = ctx.stream() var new_stream = ctx.create_stream() # Create event in default_stream var event = ctx.create_event() # Wait for the event in new_stream new_stream.enqueue_wait_for(event) # Stream 2 can continue default_stream.record_event(event) ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__del__` `__del__(deinit self)` Releases resources associated with this event. ### `synchronize` `synchronize(self)` Blocks the calling CPU thread until this event completes. This function waits until the event has been recorded and all operations before the event in the stream have completed. **Raises:** If synchronization fails.
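Building on the example above, the sketch below shows blocking host-side synchronization with an event. Note that `event.synchronize()` only blocks usefully when the event was created with `blocking_sync=True` (see `DeviceContext.create_event()`); the work enqueued before `record_event()` is a placeholder:

```mojo
from gpu.host import DeviceContext

var ctx = DeviceContext()
var stream = ctx.stream()

# blocking_sync=True lets the host block in event.synchronize().
var event = ctx.create_event[blocking_sync=True]()

# Enqueue GPU work on the stream here...
stream.record_event(event)

# Block the calling CPU thread until all work recorded before the
# event has completed.
event.synchronize()
```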
--- ## DeviceExternalFunction
`struct DeviceExternalFunction` Represents an external device function loaded from PTX/SASS assembly. This class provides functionality to load and execute pre-compiled GPU functions from assembly code rather than compiling them from Mojo source. This is useful for integrating with existing CUDA/HIP code or for using specialized assembly optimizations. The `DeviceExternalFunction` handles reference counting of the underlying device function handle and provides methods for launching the function on a GPU with specified execution configuration. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = False` ## Methods ### `__copyinit__` `__copyinit__(out self, existing: Self)` Creates a copy of an existing device function by incrementing its reference count. **Args:** * ​existing (`Self`): The device function to copy. ### `__del__` `__del__(deinit self)` Releases resources associated with this device function. ### `get_attribute` `get_attribute(self, attr: Attribute) -> Int` Retrieves a specific attribute of this device function. **Args:** * ​attr ([`Attribute`](/mojo/stdlib/gpu/host/func_attribute/Attribute)): The attribute to query. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value of the requested attribute. **Raises:** If the attribute query fails.
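As a sketch of how a loaded external function might be queried, following the elided-constructor style of the other examples in this reference (the construction arguments are intentionally omitted, and the `Attribute` import path mirrors the `DeviceFunction.get_attribute` example below):

```mojo
from gpu.host import Attribute, DeviceExternalFunction

var ext_func = DeviceExternalFunction(...)

# Ask the driver how many threads per block this pre-compiled
# function supports.
var max_threads = ext_func.get_attribute(Attribute.MAX_THREADS_PER_BLOCK)
print("Max threads per block:", max_threads)
```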
--- ## DeviceFunction
`struct DeviceFunction[func_type: AnyTrivialRegType, //, func: func_type, declared_arg_types: Optional[Variadic[AnyType]], *, target: __mlir_type.`!kgen.target` = get_gpu_target(), compile_options: StringSlice[StaticConstantOrigin] = CompilationTarget.default_compile_options[target](), _ptxas_info_verbose: Bool = False]` Represents a compiled device function for GPU execution. This struct encapsulates a compiled GPU function that can be launched on a device. It handles the compilation, loading, and resource management of device functions. Example: ```mojo from gpu.host import DeviceContext, DeviceFunction fn my_kernel(x: Int, y: Int): # Kernel implementation pass var ctx = DeviceContext() var kernel = ctx.compile_function_checked[my_kernel, my_kernel]() ctx.enqueue_function_checked(kernel, grid_dim=(1,1,1), block_dim=(32,1,1)) ``` ## Parameters * ​func\_type ([`AnyTrivialRegType`](/mojo/stdlib/builtin/type_aliases/#anytrivialregtype)): The type of the function to compile. * ​func (`func_type`): The function to compile for GPU execution. * ​declared\_arg\_types ([`Optional`](/mojo/stdlib/collections/optional/Optional)): An optional containing a variadic of the declared types of the kernel signature. * ​target (`__mlir_type.`!kgen.target\`\`): The target architecture for compilation. Defaults to the current GPU target. * ​compile\_options ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The string of compilation options to pass to the compiler. * ​\_ptxas\_info\_verbose ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether to enable verbose PTX assembly output. Defaults to False. 
## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__copyinit__` `__copyinit__(out self, existing: Self)` Creates a copy of an existing DeviceFunction. This increases the reference count of the underlying device function handle. **Args:** * ​existing (`Self`): The DeviceFunction to copy from. ### `__del__` `__del__(deinit self)` Releases resources associated with this DeviceFunction. This decrements the reference count of the underlying device function handle. ### `dump_rep` `dump_rep[dump_asm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, dump_llvm: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False, _dump_sass: Variant[Bool, Path, StaticString, fn() capturing -> Path] = False](self)` Dumps various representations of the compiled device function. This method dumps the assembly, LLVM IR, and/or SASS code for the compiled device function based on the provided parameters. The output can be directed to stdout or written to files. Notes: When a path contains '%', it will be replaced with the module name to help disambiguate multiple kernel dumps. **Parameters:** * ​dump\_asm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Controls dumping of assembly code. Can be a boolean, a file path, or a function returning a file path. * ​dump\_llvm ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Controls dumping of LLVM IR. Can be a boolean, a file path, or a function returning a file path. 
* ​\_dump\_sass ([`Variant`](/mojo/stdlib/utils/variant/Variant)): Controls dumping of SASS code (internal use). Can be a boolean, a file path, or a function returning a file path. **Raises:** If any file operations fail during the dumping process. ### `get_attribute` `get_attribute(self, attr: Attribute) -> Int` Retrieves a specific attribute value from the compiled device function. This method queries the device function for information about its resource requirements, execution capabilities, or other properties defined by the specified attribute. Example: ```mojo from gpu.host import Attribute, DeviceFunction var device_function = DeviceFunction(...) # Get the maximum number of threads per block for this function var max_threads = device_function.get_attribute(Attribute.MAX_THREADS_PER_BLOCK) ``` **Args:** * ​attr ([`Attribute`](/mojo/stdlib/gpu/host/func_attribute/Attribute)): The attribute to query, defined in the Attribute enum. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The integer value of the requested attribute. **Raises:** If the attribute query fails or the attribute is not supported. ### `occupancy_max_active_blocks_per_multiprocessor` `occupancy_max_active_blocks_per_multiprocessor(self, block_size: Int, dynamic_shared_mem_size: Int) -> Int` Returns the maximum number of active blocks per multiprocessor for the given function. **Args:** * ​block\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of threads per block. * ​dynamic\_shared\_mem\_size ([`Int`](/mojo/stdlib/builtin/int/Int)): The size of dynamically allocated shared memory in bytes. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The maximum number of active blocks that can run concurrently per multiprocessor. **Raises:** If the occupancy calculation fails.
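A short sketch of an occupancy query on a compiled kernel; the kernel body is a placeholder, and the block size of 256 is an arbitrary choice for illustration:

```mojo
from gpu.host import DeviceContext

fn my_kernel():
    pass

with DeviceContext() as ctx:
    var func = ctx.compile_function_checked[my_kernel, my_kernel]()
    # How many 256-thread blocks with no dynamic shared memory can be
    # resident on a single multiprocessor at once?
    var max_blocks = func.occupancy_max_active_blocks_per_multiprocessor(
        block_size=256, dynamic_shared_mem_size=0
    )
    print("Active blocks per multiprocessor:", max_blocks)
```

Occupancy results like this are commonly used to tune launch configurations before calling `enqueue_function_checked`.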
--- ## DeviceMulticastBuffer
`struct DeviceMulticastBuffer[dtype: DType]` Represents a multicast memory object that enables special memory operations to be broadcast across a group of devices. ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type to be stored in the associated memory regions. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True`
--- ## DeviceStream
`struct DeviceStream` Represents a CUDA/HIP stream for asynchronous GPU operations. A DeviceStream provides a queue for GPU operations that can execute concurrently with operations in other streams. Operations within a single stream execute in the order they are issued, but operations in different streams may execute in any relative order or concurrently. This abstraction allows for better utilization of GPU resources by enabling overlapping of computation and data transfers. Example: ```mojo from gpu.host import DeviceContext, DeviceStream var ctx = DeviceContext(0) # Select first GPU var stream = DeviceStream(ctx) # Launch operations on the stream # ... # Wait for all operations in the stream to complete stream.synchronize() ``` ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `synchronize` `synchronize(self)` Blocks the calling CPU thread until all operations in this stream complete. This function waits until all previously issued commands in this stream have completed execution. It provides a synchronization point between host and device code. Example: ```mojo # Launch kernel or memory operations on the stream # ... # Wait for completion stream.synchronize() # Now it's safe to use results on the host ``` **Raises:** If synchronization fails. ### `enqueue_wait_for` `enqueue_wait_for(self, event: DeviceEvent)` Makes this stream wait for the specified event. 
This function inserts a wait operation into this stream that will block all subsequent operations in the stream until the specified event has been recorded and completed. **Args:** * ​event ([`DeviceEvent`](/mojo/stdlib/gpu/host/device_context/DeviceEvent)): The event to wait for. **Raises:** If the wait operation fails. ### `record_event` `record_event(self, event: DeviceEvent)` Records an event in this stream. This function records the given event at the current point in this stream. All operations in the stream that were enqueued before this call will complete before the event is triggered. Example: ```mojo from gpu.host import DeviceContext var ctx = DeviceContext() var default_stream = ctx.stream() var new_stream = ctx.create_stream() # Create event on the default stream var event = default_stream.create_event() # Wait for the event on the new stream new_stream.enqueue_wait_for(event) # Stream 2 can continue default_stream.record_event(event) ``` **Args:** * ​event ([`DeviceEvent`](/mojo/stdlib/gpu/host/device_context/DeviceEvent)): The event to record. **Raises:** If event recording fails. ### `enqueue_function` `enqueue_function[*Ts: AnyType](self, f: DeviceFunction[func, declared_arg_types, target=target, compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose], *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](, Tuple[]()), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping](, Tuple[]()))` Enqueues a compiled function for execution on this device. 
You can pass the function directly to `enqueue_function` without compiling it first: ```mojo from gpu.host import DeviceContext fn kernel(): print("hello from the GPU") with DeviceContext() as ctx: ctx.enqueue_function[kernel](grid_dim=1, block_dim=1) ctx.synchronize() ``` If you are reusing the same function and parameters multiple times, this incurs 50-500 nanoseconds of overhead per enqueue, so you can compile the function first to remove the overhead: ```mojo from gpu.host import DeviceContext with DeviceContext() as ctx: var compiled_func = ctx.compile_function_checked[kernel, kernel]() ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.enqueue_function_checked(compiled_func, grid_dim=1, block_dim=1) ctx.synchronize() ``` **Parameters:** * ​\*Ts ([`AnyType`](/mojo/stdlib/builtin/anytype/AnyType)): Argument types. **Args:** * ​f ([`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction)): The compiled function to execute. * ​\*args (`*Ts`): Arguments to pass to the function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of the compute grid, made up of thread blocks. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of each thread block in the grid. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Dimensions of clusters (if the thread blocks are grouped into clusters). * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory per thread block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): Launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): Constant memory mapping. **Raises:** If the operation fails. 
### `enqueue_function_checked` `enqueue_function_checked[*Ts: DevicePassable](self, f: DeviceFunction[func, declared_arg_types, target=target, compile_options=compile_options, _ptxas_info_verbose=_ptxas_info_verbose], *args: *Ts, *, grid_dim: Dim, block_dim: Dim, cluster_dim: OptionalReg[Dim] = None, shared_mem_bytes: OptionalReg[Int] = None, var attributes: List[LaunchAttribute] = List[LaunchAttribute](), var constant_memory: List[ConstantMemoryMapping] = List[ConstantMemoryMapping]())` Enqueues a checked compiled function for execution on this stream. **Parameters:** * ​\*Ts ([`DevicePassable`](/mojo/stdlib/builtin/device_passable/DevicePassable)): Argument types (must be DevicePassable). **Args:** * ​f ([`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction)): The checked compiled function to execute. * ​\*args (`*Ts`): Arguments to pass to the function. * ​grid\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of the compute grid, made up of thread blocks. * ​block\_dim ([`Dim`](/mojo/stdlib/gpu/host/dim/Dim)): Dimensions of each thread block in the grid. * ​cluster\_dim ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Dimensions of clusters (if the thread blocks are grouped into clusters). * ​shared\_mem\_bytes ([`OptionalReg`](/mojo/stdlib/collections/optional/OptionalReg)): Amount of shared memory per thread block. * ​attributes ([`List`](/mojo/stdlib/collections/list/List)): Launch attributes. * ​constant\_memory ([`List`](/mojo/stdlib/collections/list/List)): Constant memory mapping. **Raises:** If the operation fails.
--- ## EventFlags
`@register_passable(trivial)` `struct EventFlags` Provides flags for creating events. These flags can be combined using the bitwise OR operator (`|`, `|=`). ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `blocking_sync` `comptime blocking_sync = EventFlags(1)` Allows `event.synchronize()` to block until the event has been recorded. ### `default` `comptime default = EventFlags(0)` Default event flags, with timing enabled. ### `disable_timing` `comptime disable_timing = EventFlags(2)` Removes timing overhead. ### `interprocess` `comptime interprocess = EventFlags(4)` Enable interprocess synchronization, currently unimplemented. ## Methods ### `__init__` `__init__(flags: UInt32) -> Self` Initializes a new EventFlags. **Args:** * ​flags ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): The flags to initialize the EventFlags with. ### `__or__` `__or__(self, other: Self) -> Self` Returns the current flags combined with another flag. **Args:** * ​other (`Self`): The flag to combine with the current flags. **Returns:** `Self`: A new EventFlags instance with the combined flags. ### `__ior__` `__ior__(mut self, other: Self)` Combines the current flags with another flag in-place. **Args:** * ​other (`Self`): The flag to combine with the current flags.
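For example, the flags above might be combined as follows; a minimal sketch, assuming `EventFlags` is importable from `gpu.host.device_context` and that the `comptime` members are accessed as shown:

```mojo
from gpu.host.device_context import EventFlags

fn main():
    # Start from a blocking-sync event and strip timing overhead in place.
    var flags = EventFlags.blocking_sync
    flags |= EventFlags.disable_timing

    # Equivalent single-expression form using `__or__`.
    var same = EventFlags.blocking_sync | EventFlags.disable_timing
```

Such a combined value would then be handed to whichever event-creation call accepts flags.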
--- ## HostBuffer
`struct HostBuffer[dtype: DType]` Represents a block of host-resident storage. For GPU devices, a host buffer is allocated in the host's global memory. To allocate a `HostBuffer`, use one of the methods provided by `DeviceContext`, such as [`enqueue_create_host_buffer()`](/mojo/stdlib/gpu/host/device_context/DeviceContext#enqueue_create_host_buffer). ## Parameters * ​dtype ([`DType`](/mojo/stdlib/builtin/dtype/DType)): Data type to be stored in the buffer. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Sized`](/mojo/stdlib/builtin/len/Sized), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = False` ### `__del__is_trivial` `comptime __del__is_trivial = False` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__copyinit__` `__copyinit__(out self, existing: Self)` Creates a copy of an existing host buffer by incrementing its reference count. This copy constructor creates a new reference to the same underlying host buffer by incrementing the reference count of the native buffer object. Both the original and the copy will refer to the same host memory. **Args:** * ​existing (`Self`): The host buffer to copy. ### `__del__` `__del__(deinit self)` Releases resources associated with this host buffer. This function schedules an owned buffer free using the stream in the device context. The actual deallocation may occur asynchronously after all operations using this buffer have completed. ### `__getitem__` `__getitem__(self, idx: Int) -> Scalar[dtype]` Retrieves the element at the specified index from the host buffer.
This operator allows direct access to individual elements in the host buffer using array indexing syntax. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to retrieve. **Returns:** [`Scalar`](/mojo/stdlib/builtin/simd/#scalar): The scalar value at the specified index. ### `__setitem__` `__setitem__(self, idx: Int, val: Scalar[dtype])` Sets the element at the specified index in the host buffer. This operator allows direct modification of individual elements in the host buffer using array indexing syntax. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index of the element to modify. * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The new value to store at the specified index. ### `__len__` `__len__(self) -> Int` Returns the number of elements in this buffer. This method calculates the number of elements by dividing the total byte size of the buffer by the size of each element. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The number of elements in the buffer. ### `create_sub_buffer` `create_sub_buffer[view_type: DType](self, offset: Int, size: Int) -> HostBuffer[view_type]` Creates a sub-buffer view of this buffer with a different element dtype. This method creates a new buffer that references a subset of the memory in this buffer, potentially with a different element dtype. The sub-buffer shares the underlying memory with the original buffer. **Parameters:** * ​view\_type ([`DType`](/mojo/stdlib/builtin/dtype/DType)): The data type for elements in the new sub-buffer. **Args:** * ​offset ([`Int`](/mojo/stdlib/builtin/int/Int)): The starting offset in elements from the beginning of this buffer. * ​size ([`Int`](/mojo/stdlib/builtin/int/Int)): The number of elements in the new sub-buffer. **Returns:** `HostBuffer`: A new HostBuffer referencing the specified region with the specified element dtype. **Raises:** If the operation fails. 
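A minimal sketch of element access on a freshly allocated buffer, assuming `DeviceContext.enqueue_create_host_buffer()` takes the element count as shown:

```mojo
from gpu.host import DeviceContext

fn main() raises:
    with DeviceContext() as ctx:
        # Allocate a host buffer of 4 float32 elements.
        var buf = ctx.enqueue_create_host_buffer[DType.float32](4)
        ctx.synchronize()

        # Direct element access via __setitem__ / __getitem__,
        # and element count via __len__.
        buf[0] = 1.5
        print(buf[0], len(buf))
```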
### `enqueue_copy_to` `enqueue_copy_to(self, dst: Self)` Enqueues an asynchronous copy from this buffer to another host buffer. This method schedules a memory copy operation from this buffer to the destination buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​dst (`Self`): The destination host buffer to copy data to. **Raises:** If the operation fails. `enqueue_copy_to(self, dst: DeviceBuffer[dtype])` Enqueues an asynchronous copy from this buffer to a device buffer. This method schedules a memory copy operation from this buffer to the destination buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​dst ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): The destination device buffer to copy data to. **Raises:** If the operation fails. `enqueue_copy_to(self, dst_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin])` Enqueues an asynchronous copy from this buffer to host memory. This method schedules a memory copy operation from this host buffer to the specified host memory location. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​dst\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to the destination host memory location. **Raises:** If the operation fails. ### `enqueue_copy_from` `enqueue_copy_from(self, src: Self)` Enqueues an asynchronous copy to this buffer from another host buffer. This method schedules a memory copy operation to this buffer from the source buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​src (`Self`): The source host buffer to copy data from. **Raises:** If the operation fails. `enqueue_copy_from(self, src: DeviceBuffer[dtype])` Enqueues an asynchronous copy to this buffer from a device buffer.
This method schedules a memory copy operation to this buffer from the source buffer. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​src ([`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer)): The source device buffer to copy data from. **Raises:** If the operation fails. `enqueue_copy_from(self, src_ptr: UnsafePointer[Scalar[dtype], origin])` Enqueues an asynchronous copy to this buffer from host memory. This method schedules a memory copy operation to this host buffer from the specified host memory location. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​src\_ptr ([`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer)): Pointer to the source host memory location. **Raises:** If the operation fails. ### `enqueue_fill` `enqueue_fill(self, val: Scalar[dtype])` Enqueues an operation to fill this buffer with a specified value. This method schedules a memory set operation that fills the entire buffer with the specified value. The operation is asynchronous and will be executed in the stream associated with this buffer's context. **Args:** * ​val ([`Scalar`](/mojo/stdlib/builtin/simd/#scalar)): The value to fill the buffer with. **Raises:** If the operation fails. ### `reassign_ownership_to` `reassign_ownership_to(self, ctx: DeviceContext)` Transfers ownership of this buffer to another device context. This method changes the device context that owns this buffer. This can be useful when sharing buffers between different contexts or when migrating workloads between devices. **Args:** * ​ctx ([`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext)): The new device context to take ownership of this buffer. **Raises:** If the operation fails. ### `take_ptr` `take_ptr(var self) -> UnsafePointer[Scalar[dtype], MutAnyOrigin]` Takes ownership of the host pointer from this buffer.
This method releases the host pointer from the buffer's control and returns it to the caller. After this call, the buffer no longer owns the pointer, and the caller is responsible for managing its lifecycle. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The raw host pointer that was owned by this buffer. ### `unsafe_ptr` `unsafe_ptr(self) -> UnsafePointer[Scalar[dtype], MutAnyOrigin]` Returns the raw host pointer without transferring ownership. This method provides direct access to the underlying host pointer for advanced use cases. The buffer retains ownership of the pointer. **Returns:** [`UnsafePointer`](/mojo/stdlib/memory/unsafe_pointer/UnsafePointer): The raw host pointer owned by this buffer. ### `context` `context(self) -> DeviceContext` Returns the device context associated with this buffer. This method retrieves the device context that owns this buffer and is responsible for managing its lifecycle and operations. **Returns:** [`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext): The device context associated with this buffer. **Raises:** If the operation fails. ### `write_to` `write_to(self, mut writer: T)` Writes a string representation of this buffer to the provided writer. This method formats the buffer's contents as a string and writes it to the specified writer. For large buffers, a compact representation is used. **Args:** * ​writer (`T`): The writer to output the formatted string to. ### `__str__` `__str__(self) -> String` Returns a string representation of the `HostBuffer`. This method creates a human-readable string representation of the buffer's contents by formatting the elements. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string containing the formatted buffer contents.
### `as_span` `as_span[mut: Bool, origin: Origin[mut], //](ref [origin] self) -> Span[Scalar[dtype], origin]` Returns a `Span` pointing to the underlying memory of the `HostBuffer`. **Parameters:** * ​mut ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Whether the span should be mutable. * ​origin ([`Origin`](/mojo/stdlib/builtin/type_aliases/Origin)): The origin of the buffer reference. **Returns:** [`Span`](/mojo/stdlib/memory/span/Span): A `Span` pointing to the underlying memory of the `HostBuffer`.
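The copy and fill methods above can be sketched together; a hedged example, assuming `DeviceContext.enqueue_create_buffer()` allocates the device-side counterpart:

```mojo
from gpu.host import DeviceContext

fn main() raises:
    with DeviceContext() as ctx:
        var host_buf = ctx.enqueue_create_host_buffer[DType.float32](16)
        var dev_buf = ctx.enqueue_create_buffer[DType.float32](16)

        # Fill the host buffer, then round-trip it through device memory.
        host_buf.enqueue_fill(1.0)
        host_buf.enqueue_copy_to(dev_buf)
        dev_buf.enqueue_copy_to(host_buf)

        # All three operations are asynchronous; wait for completion.
        ctx.synchronize()
```

Because every call is enqueued on the context's stream, ordering is preserved without explicit events.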
--- ## StreamPriorityRange
`@register_passable(trivial)` `struct StreamPriorityRange` Represents the range of valid stream priorities for a GPU device. Stream priorities control the scheduling of GPU operations, with higher priority streams being executed preferentially over lower priority streams. ## Fields * ​least (`Int`): The lowest (numerically smallest) priority value. * ​greatest (`Int`): The highest (numerically largest) priority value. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__str__` `__str__(self) -> String` Returns a string representation of the stream priority range. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): A string in the format "StreamPriorityRange(least=X, greatest=Y)". ### `write_to` `write_to(self, mut writer: T)` Writes the stream priority range to the given writer. **Args:** * ​writer (`T`): The writer to output the stream priority range to.
--- ## device_context
This module provides functionality for interacting with accelerators. In particular, it provides the [`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext) struct, which represents a single stream of execution on a given accelerator. You can use this struct to allocate accelerator memory, copy data to and from the accelerator, and compile and execute functions on the accelerator. ## Structs * [​`DeviceBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceBuffer): Represents a block of device-resident storage. For GPU devices, a device buffer is allocated in the device's global memory. * [​`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext): Represents a single stream of execution on a particular accelerator (GPU). * [​`DeviceEvent`](/mojo/stdlib/gpu/host/device_context/DeviceEvent): Represents a GPU event for synchronization between streams. * [​`DeviceExternalFunction`](/mojo/stdlib/gpu/host/device_context/DeviceExternalFunction): Represents an external device function loaded from PTX/SASS assembly. * [​`DeviceFunction`](/mojo/stdlib/gpu/host/device_context/DeviceFunction): Represents a compiled device function for GPU execution. * [​`DeviceMulticastBuffer`](/mojo/stdlib/gpu/host/device_context/DeviceMulticastBuffer): Represents a multicast memory object that enables special memory operations to be broadcast across a group of devices. * [​`DeviceStream`](/mojo/stdlib/gpu/host/device_context/DeviceStream): Represents a CUDA/HIP stream for asynchronous GPU operations. * [​`EventFlags`](/mojo/stdlib/gpu/host/device_context/EventFlags): Provides flags for creating events. * [​`HostBuffer`](/mojo/stdlib/gpu/host/device_context/HostBuffer): Represents a block of host-resident storage. For GPU devices, a host buffer is allocated in the host's global memory. * [​`StreamPriorityRange`](/mojo/stdlib/gpu/host/device_context/StreamPriorityRange): Represents the range of valid stream priorities for a GPU device.
--- ## Dim
`@register_passable(trivial)` `struct Dim` Represents a dimension with up to three components (x, y, z). This struct is commonly used to represent grid and block dimensions for kernel launches. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__init__` `@implicit` `__init__[I: Indexer, //](x: I) -> Self` Initializes Dim with a single indexable value for x. y and z dimensions are set to 1. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the indexable value. **Args:** * ​x (`I`): The value for the x dimension. `__init__[I0: Indexer, I1: Indexer, //](x: I0, y: I1) -> Self` Initializes Dim with indexable values for x and y. z dimension is set to 1. **Parameters:** * ​I0 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the first indexable value. * ​I1 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the second indexable value. **Args:** * ​x (`I0`): The value for the x dimension. * ​y (`I1`): The value for the y dimension. `__init__[I0: Indexer, I1: Indexer, I2: Indexer, //](x: I0, y: I1, z: I2) -> Self` Initializes Dim with indexable values for x, y, and z. **Parameters:** * ​I0 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the first indexable value. * ​I1 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the second indexable value. 
* ​I2 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer)): The type of the third indexable value. **Args:** * ​x (`I0`): The value for the x dimension. * ​y (`I1`): The value for the y dimension. * ​z (`I2`): The value for the z dimension. `@implicit` `__init__[I: Indexer & Copyable, //](dims: Tuple[I]) -> Self` Initializes Dim with a tuple containing a single indexable value. y and z dimensions are set to 1. **Parameters:** * ​I ([`Indexer`](/mojo/stdlib/builtin/int/Indexer) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the indexable value in the tuple. **Args:** * ​dims ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): A tuple with one element for x dimension. `@implicit` `__init__[I0: Indexer & Copyable, I1: Indexer & Copyable, //](dims: Tuple[I0, I1]) -> Self` Initializes Dim with a tuple of two indexable values. The z dimension is set to 1. **Parameters:** * ​I0 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the first indexable value in the tuple. * ​I1 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the second indexable value in the tuple. **Args:** * ​dims ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): A tuple with two elements: x and y dimensions. `@implicit` `__init__[I0: Indexer & Copyable, I1: Indexer & Copyable, I2: Indexer & Copyable, //](dims: Tuple[I0, I1, I2]) -> Self` Initializes Dim with a tuple of three indexable values. **Parameters:** * ​I0 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the first indexable value in the tuple. * ​I1 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the second indexable value in the tuple. * ​I2 ([`Indexer`](/mojo/stdlib/builtin/int/Indexer) & [`Copyable`](/mojo/stdlib/builtin/value/Copyable)): The type of the third indexable value in the tuple. 
**Args:** * ​dims ([`Tuple`](/mojo/stdlib/builtin/tuple/Tuple)): Tuple with three elements: x, y, and z dimensions. ### `__getitem__` `__getitem__(self, idx: Int) -> Int` Gets the dimension value at the specified index. **Args:** * ​idx ([`Int`](/mojo/stdlib/builtin/int/Int)): The index (0 for x, 1 for y, 2 for z). **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value of the dimension at the given index. ### `__str__` `__str__(self) -> String` Returns a string representation of the Dim. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String representation of this Dim object. ### `__repr__` `__repr__(self) -> String` Returns a string representation of the Dim. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String representation of this Dim object. ### `write_to` `write_to(self, mut writer: T)` Writes a formatted string representation of the Dim. **Args:** * ​writer (`T`): The Writer to write to. ### `z` `z(self) -> Int` Returns the z dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value of the z dimension. ### `y` `y(self) -> Int` Returns the y dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value of the y dimension. ### `x` `x(self) -> Int` Returns the x dimension. **Returns:** [`Int`](/mojo/stdlib/builtin/int/Int): The value of the x dimension.
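The constructors and accessors above can be exercised as follows; a minimal sketch, assuming `Dim` is re-exported from `gpu.host` (its home module is `gpu.host.dim`):

```mojo
from gpu.host import Dim

fn main():
    var d1: Dim = 16         # implicit single-value init: x=16, y=1, z=1
    var d2 = Dim(8, 8)       # x=8, y=8, z=1
    var d3 = Dim((4, 4, 2))  # implicit init from a 3-tuple

    print(d2.x(), d2.y(), d2.z())  # accessor methods, not fields
    print(d3[2])                   # indexed access: 0=x, 1=y, 2=z
```

The implicit single-value initializer is what lets kernel launches write `grid_dim=1, block_dim=1` directly.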
--- ## dim (Dim)
This module implements the dim type. ## Structs * [​`Dim`](/mojo/stdlib/gpu/host/dim/Dim): Represents a dimension with up to three components (x, y, z).
--- ## Attribute
`@register_passable(trivial)` `struct Attribute` Represents GPU kernel function attributes. This struct defines constants for various function attributes that can be queried or set for GPU kernels. These attributes provide information about resource requirements and execution constraints of kernel functions. ## Fields * ​code (`Int32`): The numeric code representing the attribute type. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Identifiable`](/mojo/stdlib/builtin/identifiable/Identifiable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `BINARY_VERSION` `comptime BINARY_VERSION = Attribute(6)` The binary architecture version for which the function was compiled. This value is the major binary version \* 10 + the minor binary version, so a binary version 1.3 function would return the value 13. Note that this will return a value of 10 for legacy cubins that do not have a properly-encoded binary architecture version. ### `CACHE_MODE_CA` `comptime CACHE_MODE_CA = Attribute(7)` The attribute to indicate whether the function has been compiled with the user-specified option "-Xptxas --dlcm=ca" set. ### `CLUSTER_SCHEDULING_POLICY_PREFERENCE` `comptime CLUSTER_SCHEDULING_POLICY_PREFERENCE = Attribute(15)` The block scheduling policy of a function. The value type is CUclusterSchedulingPolicy / cudaClusterSchedulingPolicy.
### `CLUSTER_SIZE_MUST_BE_SET` `comptime CLUSTER_SIZE_MUST_BE_SET = Attribute(10)` If this attribute is set, the kernel must launch with a valid cluster size specified. ### `CONST_SIZE_BYTES` `comptime CONST_SIZE_BYTES = Attribute(2)` The size in bytes of user-allocated constant memory required by this function. ### `LOCAL_SIZE_BYTES` `comptime LOCAL_SIZE_BYTES = Attribute(3)` The size in bytes of local memory used by each thread of this function. ### `MAX_DYNAMIC_SHARED_SIZE_BYTES` `comptime MAX_DYNAMIC_SHARED_SIZE_BYTES = Attribute(8)` The maximum size in bytes of dynamically-allocated shared memory that can be used by this function. If the user-specified dynamic shared memory size is larger than this value, the launch will fail. ### `MAX_THREADS_PER_BLOCK` `comptime MAX_THREADS_PER_BLOCK = Attribute(0)` The maximum number of threads per block, beyond which a launch of the function would fail. This number depends on both the function and the device on which the function is currently loaded. ### `NON_PORTABLE_CLUSTER_SIZE_ALLOWED` `comptime NON_PORTABLE_CLUSTER_SIZE_ALLOWED = Attribute(14)` Whether the function can be launched with non-portable cluster size. 1 is allowed, 0 is disallowed. A non-portable cluster size may only function on the specific SKUs the program is tested on. The launch might fail if the program is run on a different hardware platform. The CUDA API provides `cudaOccupancyMaxActiveClusters` to assist with checking whether the desired size can be launched on the current device. A portable cluster size is guaranteed to be functional on all compute capabilities higher than the target compute capability; the portable cluster size for sm\_90 is 8 blocks per cluster. ### `NUM_REGS` `comptime NUM_REGS = Attribute(4)` The number of registers used by each thread of this function.
### `PREFERRED_SHARED_MEMORY_CARVEOUT` `comptime PREFERRED_SHARED_MEMORY_CARVEOUT = Attribute(9)` On devices where the L1 cache and shared memory use the same hardware resources, this sets the shared memory carveout preference, in percent of the total shared memory. ### `PTX_VERSION` `comptime PTX_VERSION = Attribute(5)` The PTX virtual architecture version for which the function was compiled. This value is the major PTX version \* 10 + the minor PTX version, so a PTX version 1.3 function would return the value 13. Note that this may return the undefined value of 0 for cubins compiled prior to CUDA 3.0. ### `REQUIRED_CLUSTER_DEPTH` `comptime REQUIRED_CLUSTER_DEPTH = Attribute(13)` The required cluster depth in blocks. The values must either all be 0 or all be positive. The validity of the cluster dimensions is otherwise checked at launch time. ### `REQUIRED_CLUSTER_HEIGHT` `comptime REQUIRED_CLUSTER_HEIGHT = Attribute(12)` The required cluster height in blocks. The values must either all be 0 or all be positive. The validity of the cluster dimensions is otherwise checked at launch time. ### `REQUIRED_CLUSTER_WIDTH` `comptime REQUIRED_CLUSTER_WIDTH = Attribute(11)` The required cluster width in blocks. The values must either all be 0 or all be positive. The validity of the cluster dimensions is otherwise checked at launch time. ### `SHARED_SIZE_BYTES` `comptime SHARED_SIZE_BYTES = Attribute(1)` The size in bytes of statically-allocated shared memory required by this function. This does not include dynamically-allocated shared memory requested by the user at runtime. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if two Attribute instances are equal. **Args:** * ​other (`Self`): The Attribute to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both attributes have the same code, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Checks if two Attribute instances are not equal.
**Args:** * ​other (`Self`): The Attribute to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if the attributes have different codes, False otherwise. ### `__is__` `__is__(self, other: Self) -> Bool` Identity comparison operator for Attribute instances. **Args:** * ​other (`Self`): The Attribute to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both attributes are identical (have the same code), False otherwise. ### `write_to` `write_to(self, mut writer: T)` Writes a string representation of the `Attribute` to the provided writer. This method converts the `Attribute` enum value to its corresponding string name and writes it to the provided writer object. **Args:** * ​writer (`T`): A Writer object that will receive the string representation.
--- ## FuncAttribute
`@register_passable(trivial)` `struct FuncAttribute` Implements CUDA's CUfunction\_attribute enum for GPU kernel function attributes. This struct represents function attributes that can be set or queried for GPU kernels, following NVIDIA's CUDA driver API conventions. Each attribute consists of a type (represented by the Attribute enum) and an associated value. The struct provides factory methods for creating common attribute configurations, such as cache mode settings and shared memory allocations. Reference: ## Fields * ​attribute (`Attribute`): The type of function attribute. * ​value (`Int32`): The value associated with this attribute. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `NULL` `comptime NULL = FuncAttribute(Attribute(-1), -1)` A null/invalid function attribute constant. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if two `FuncAttribute` instances are equal. **Args:** * ​other (`Self`): The FuncAttribute to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both the attribute type and value are equal, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Checks if two `FuncAttribute` instances are not equal. **Args:** * ​other (`Self`): The `FuncAttribute` to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if either the attribute type or value differs, False otherwise. 
### `CACHE_MODE_CA` `static CACHE_MODE_CA(val: Bool) -> Self` Creates a CACHE\_MODE\_CA function attribute. Indicates whether the function has been compiled with user specified option `CacheMode.L1_CACHE_DISABLED` set. **Args:** * ​val ([`Bool`](/mojo/stdlib/builtin/bool/Bool)): Boolean value indicating if L1 cache is disabled. **Returns:** `Self`: A `FuncAttribute` instance with CACHE\_MODE\_CA attribute type. ### `MAX_DYNAMIC_SHARED_SIZE_BYTES` `static MAX_DYNAMIC_SHARED_SIZE_BYTES(val: UInt32) -> Self` Creates a MAX\_DYNAMIC\_SHARED\_SIZE\_BYTES function attribute. The maximum size in bytes of dynamically-allocated shared memory that can be used by this function. If the user-specified dynamic shared memory size is larger than this value, the launch will fail. **Args:** * ​val ([`UInt32`](/mojo/stdlib/builtin/simd/#uint32)): Maximum dynamic shared memory size in bytes. **Returns:** `Self`: A `FuncAttribute` instance with `MAX_DYNAMIC_SHARED_SIZE_BYTES` attribute type. ### `PREFERRED_SHARED_MEMORY_CARVEOUT` `static PREFERRED_SHARED_MEMORY_CARVEOUT(val: Int32) -> Self` Creates a PREFERRED\_SHARED\_MEMORY\_CARVEOUT function attribute. On devices where the L1 cache and shared memory use the same hardware resources, this sets the shared memory carveout preference, in percent of the total shared memory. **Args:** * ​val ([`Int32`](/mojo/stdlib/builtin/simd/#int32)): Shared memory carveout preference as a percentage (0-100). **Returns:** `Self`: A FuncAttribute instance with `PREFERRED_SHARED_MEMORY_CARVEOUT` attribute type.
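The factory methods above might be used as in this minimal sketch, assuming `FuncAttribute` is importable from `gpu.host` (its home module is `gpu.host.func_attribute`):

```mojo
from gpu.host import FuncAttribute

fn main():
    # Allow up to 64 KiB of dynamic shared memory for a kernel launch.
    var shmem = FuncAttribute.MAX_DYNAMIC_SHARED_SIZE_BYTES(65536)

    # Prefer a 50% shared-memory carveout on devices where L1 and
    # shared memory share hardware resources.
    var carveout = FuncAttribute.PREFERRED_SHARED_MEMORY_CARVEOUT(50)

    # FuncAttribute is Equatable: compares both attribute type and value.
    print(shmem == carveout)
```

Each factory pairs an `Attribute` code with its value, so the resulting structs can be passed wherever a function attribute is accepted.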
--- ## func_attribute
GPU Kernel Function Attributes Module This module provides structures for defining and managing GPU kernel function attributes. It implements functionality similar to CUDA's CUfunction\_attribute enum, allowing for querying and setting various attributes that control kernel execution behavior and resource allocation. The module includes: * `Attribute`: A value type representing different GPU kernel function attribute types * `FuncAttribute`: A structure that pairs an attribute type with its value These structures enable fine-grained control over GPU kernel execution parameters such as shared memory allocation, cache behavior, and cluster configuration. ## Structs * [​`Attribute`](/mojo/stdlib/gpu/host/func_attribute/Attribute): Represents GPU kernel function attributes. * [​`FuncAttribute`](/mojo/stdlib/gpu/host/func_attribute/FuncAttribute): Implements CUDA's CUfunction\_attribute enum for GPU kernel function attributes.
--- ## host
Implements the gpu host package. ## Packages * [​`nvidia`](/mojo/stdlib/gpu/host/nvidia/): Implements the tma package. ## Modules * [​`compile`](/mojo/stdlib/gpu/host/compile/): Implements CUDA compilation operations. * [​`constant_memory_mapping`](/mojo/stdlib/gpu/host/constant_memory_mapping/): This module provides functionality for mapping constant memory between host and device. * [​`device_attribute`](/mojo/stdlib/gpu/host/device_attribute/): This module defines GPU device attributes that can be queried from CUDA-compatible devices. * [​`device_context`](/mojo/stdlib/gpu/host/device_context/): This module provides functionality for interacting with accelerators. In particular the [`DeviceContext`](/mojo/stdlib/gpu/host/device_context/DeviceContext) struct, which represents a single stream of execution on a given accelerator. You can use this struct to allocate accelerator memory, copy data to and from the accelerator, and compile and execute functions on the accelerator. * [​`dim`](/mojo/stdlib/gpu/host/dim/): This module implements the dim type. * [​`func_attribute`](/mojo/stdlib/gpu/host/func_attribute/): GPU Kernel Function Attributes Module * [​`info`](/mojo/stdlib/gpu/host/info/): Contains information about GPU architectures and their capabilities. * [​`launch_attribute`](/mojo/stdlib/gpu/host/launch_attribute/): GPU Launch Attributes for Kernel Execution Control
--- ## AcceleratorArchitectureFamily
`@register_passable(trivial)` `struct AcceleratorArchitectureFamily` Defines common defaults for a GPU architecture family. This struct captures the shared characteristics across GPUs in the same architecture family, reducing redundancy when defining new GPU models. ## Fields * ​warp\_size (`Int`): Number of threads in a warp/wavefront. * ​threads\_per\_multiprocessor (`Int`): Maximum number of threads per streaming multiprocessor. * ​shared\_memory\_per\_multiprocessor (`Int`): Size of shared memory available per multiprocessor in bytes. * ​max\_registers\_per\_block (`Int`): Maximum number of registers that can be allocated to a thread block. * ​max\_thread\_block\_size (`Int`): Maximum number of threads allowed in a thread block. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True`
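As an illustration of how a family bundles shared defaults, a new architecture family might be described once and reused across models. This is a hypothetical sketch: it assumes a fieldwise initializer, takes the import path from this reference's module layout, and uses Ampere-like values for illustration only, not as official specifications:

```mojo
from gpu.host.info import AcceleratorArchitectureFamily

# Hypothetical family definition; field values are illustrative.
comptime ExampleFamily = AcceleratorArchitectureFamily(
    warp_size=32,
    threads_per_multiprocessor=2048,
    shared_memory_per_multiprocessor=164 * 1024,
    max_registers_per_block=65536,
    max_thread_block_size=1024,
)
```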
--- ## GPUInfo
`@register_passable` `struct GPUInfo` Comprehensive information about a GPU architecture. This struct contains detailed specifications about GPU capabilities, including compute units, memory, thread organization, and performance characteristics. ## Fields * ​name (`StaticString`): The model name of the GPU. * ​vendor (`Vendor`): The vendor/manufacturer of the GPU (e.g., NVIDIA, AMD). * ​api (`StaticString`): The graphics/compute API supported by the GPU (e.g., CUDA, ROCm). * ​arch\_name (`StaticString`): The architecture name of the GPU (e.g., sm\_80, gfx942). * ​compute (`Float32`): Compute capability version number for NVIDIA GPUs. * ​version (`StaticString`): Version string of the GPU architecture. * ​sm\_count (`Int`): Number of streaming multiprocessors (SMs) on the GPU. * ​warp\_size (`Int`): Number of threads in a warp/wavefront. * ​threads\_per\_multiprocessor (`Int`): Maximum number of threads per streaming multiprocessor. * ​shared\_memory\_per\_multiprocessor (`Int`): Size of shared memory available per multiprocessor in bytes. * ​max\_registers\_per\_block (`Int`): Maximum number of registers that can be allocated to a thread block. * ​max\_thread\_block\_size (`Int`): Maximum number of threads allowed in a thread block. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Equatable`](/mojo/stdlib/builtin/comparable/Equatable), [`Identifiable`](/mojo/stdlib/builtin/identifiable/Identifiable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`Stringable`](/mojo/stdlib/builtin/str/Stringable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if two `GPUInfo` instances represent the same GPU model. 
**Args:** * ​other (`Self`): Another `GPUInfo` instance to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both instances represent the same GPU model. ### `__is__` `__is__(self, other: Self) -> Bool` Identity comparison operator for `GPUInfo` instances. **Args:** * ​other (`Self`): Another `GPUInfo` instance to compare against. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if both instances represent the same GPU model. ### `target` `target(self) -> __mlir_type.`!kgen.target\`\` Gets the MLIR target configuration for this GPU. **Returns:** `__mlir_type.`!kgen.target\`\`: MLIR target configuration for the GPU. ### `from_target` `static from_target[target: __mlir_type.`!kgen.target`]() -> Self` Creates a `GPUInfo` instance from an MLIR target. **Parameters:** * ​target (`__mlir_type.`!kgen.target\`\`): MLIR target configuration. **Returns:** `Self`: GPU info corresponding to the target. ### `from_name` `static from_name[name: StringSlice[StaticConstantOrigin]]() -> Self` Creates a `GPUInfo` instance from a GPU architecture name. **Parameters:** * ​name ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): GPU architecture name (e.g., "sm\_80", "gfx942"). **Returns:** `Self`: GPU info corresponding to the architecture name. ### `from_family` `static from_family(family: AcceleratorArchitectureFamily, name: StringSlice[StaticConstantOrigin], vendor: Vendor, api: StringSlice[StaticConstantOrigin], arch_name: StringSlice[StaticConstantOrigin], compute: Float32, version: StringSlice[StaticConstantOrigin], sm_count: Int) -> Self` Creates a `GPUInfo` instance using architecture family defaults. This constructor simplifies GPU definition by inheriting common characteristics from an architecture family while allowing specific values to be overridden. **Args:** * ​family ([`AcceleratorArchitectureFamily`](/mojo/stdlib/gpu/host/info/AcceleratorArchitectureFamily)): Architecture family providing default values. 
* ​name ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The model name of the GPU. * ​vendor ([`Vendor`](/mojo/stdlib/gpu/host/info/Vendor)): The vendor/manufacturer of the GPU. * ​api ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The graphics/compute API supported by the GPU. * ​arch\_name ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): The architecture name of the GPU. * ​compute ([`Float32`](/mojo/stdlib/builtin/simd/#float32)): Compute capability version number. * ​version ([`StringSlice`](/mojo/stdlib/collections/string/string_slice/StringSlice)): Version string of the GPU architecture. * ​sm\_count ([`Int`](/mojo/stdlib/builtin/int/Int)): Number of streaming multiprocessors. **Returns:** `Self`: A fully configured GPUInfo instance. ### `write_to` `write_to(self, mut writer: T)` Writes GPU information to a writer. Outputs all GPU specifications and capabilities to the provided writer in a human-readable format. **Args:** * ​writer (`T`): A Writer instance to output the GPU information. ### `__str__` `__str__(self) -> String` Returns a string representation of the GPU information. Converts all GPU specifications and capabilities to a human-readable string format. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String containing all GPU information.
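To make the lookup entry points concrete, here is a minimal sketch using `from_name` (the import path follows this reference's module layout and is an assumption):

```mojo
from gpu.host.info import GPUInfo

fn main():
    # Resolve full specifications from an architecture name at compile time.
    comptime info = GPUInfo.from_name["sm_80"]()
    print("SMs:", info.sm_count)
    print("warp size:", info.warp_size)
    print("max threads per block:", info.max_thread_block_size)
```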
--- ## Vendor (Info)
`@register_passable(trivial)` `struct Vendor` Represents GPU vendors. This struct provides identifiers for different GPU vendors and utility methods for comparison and string representation. The Vendor struct defines constants for common GPU vendors (NVIDIA, AMD) and includes a NO\_GPU option for systems without GPU support. It provides comparison operators and string conversion methods for vendor identification. ## Implemented traits [`AnyType`](/mojo/stdlib/builtin/anytype/AnyType), [`Copyable`](/mojo/stdlib/builtin/value/Copyable), [`Identifiable`](/mojo/stdlib/builtin/identifiable/Identifiable), [`ImplicitlyCopyable`](/mojo/stdlib/builtin/value/ImplicitlyCopyable), [`Movable`](/mojo/stdlib/builtin/value/Movable), [`UnknownDestructibility`](/mojo/stdlib/builtin/anytype/UnknownDestructibility), [`Writable`](/mojo/stdlib/io/write/Writable) ## `comptime` members ### `__copyinit__is_trivial` `comptime __copyinit__is_trivial = True` ### `__del__is_trivial` `comptime __del__is_trivial = True` ### `__moveinit__is_trivial` `comptime __moveinit__is_trivial = True` ### `AMD_GPU` `comptime AMD_GPU = Vendor(1)` Represents AMD GPU vendor. ### `APPLE_GPU` `comptime APPLE_GPU = Vendor(3)` Represents Apple GPU vendor. ### `NO_GPU` `comptime NO_GPU = Vendor(0)` Represents no GPU or CPU-only execution. ### `NVIDIA_GPU` `comptime NVIDIA_GPU = Vendor(2)` Represents NVIDIA GPU vendor. ## Methods ### `__eq__` `__eq__(self, other: Self) -> Bool` Checks if two `Vendor` instances are equal. **Args:** * ​other (`Self`): The `Vendor` to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if vendors are equal, False otherwise. ### `__ne__` `__ne__(self, other: Self) -> Bool` Checks if two `Vendor` instances are not equal. **Args:** * ​other (`Self`): The `Vendor` to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if vendors are not equal, False otherwise. ### `__is__` `__is__(self, other: Self) -> Bool` Identity comparison for vendors. 
**Args:** * ​other (`Self`): The `Vendor` to compare with. **Returns:** [`Bool`](/mojo/stdlib/builtin/bool/Bool): True if vendors are identical, False otherwise. ### `write_to` `write_to(self, mut writer: T)` Writes vendor information to a writer. **Args:** * ​writer (`T`): The writer to output vendor information to. ### `__str__` `__str__(self) -> String` Returns a string representation of the vendor. **Returns:** [`String`](/mojo/stdlib/collections/string/string/String): String representation of the vendor.
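A typical use is branching on the vendor of a resolved `GPUInfo`. A minimal sketch, under the same assumed import path as above:

```mojo
from gpu.host.info import GPUInfo, Vendor

fn main():
    comptime info = GPUInfo.from_name["gfx942"]()
    # Vendor supports identity comparison via `is`.
    if info.vendor is Vendor.AMD_GPU:
        print("AMD GPU:", info.name)
    elif info.vendor is Vendor.NVIDIA_GPU:
        print("NVIDIA GPU:", info.name)
    else:
        print("other vendor:", info.vendor)
```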
--- ## info
Contains information about GPU architectures and their capabilities. This module provides detailed specifications for various GPU models including NVIDIA and AMD GPUs. It includes information about compute capabilities, memory specifications, thread organization, and performance characteristics. # GPU Target Configuration Guide When adding support for a new GPU architecture, you must create a target configuration function that returns a `_TargetType`. This guide explains the components of the MLIR target configuration, with special focus on the `data_layout` string. ## MLIR Target Components Each GPU target function returns an MLIR `kgen.target` attribute with these fields: * **triple**: Target triple (e.g., "nvptx64-nvidia-cuda", "amdgcn-amd-amdhsa"). * **arch**: Architecture name (e.g., "sm\_80", "gfx942", "apple-m4"). * **features**: Target-specific features (e.g., "+ptx81,+sm\_80"). * **tune\_cpu**: Optimization target (usually same as arch, can differ for tuning). * **data\_layout**: LLVM data layout string (explained in detail below). * **index\_bit\_width**: Bit width for index types (usually 64). * **simd\_bit\_width**: SIMD register width (usually 128 for modern GPUs). ## Understanding Data Layout Strings The `data_layout` string describes memory layout characteristics for the target architecture. It follows LLVM's data layout specification format and is used by the compiler to make decisions about memory access patterns, type layouts, and optimizations. ### Format Overview The string consists of specifications separated by dashes (`-`): * **Endianness**: `e` (little-endian) or `E` (big-endian). * **Pointers**: `p[addr_space]:size:abi:pref:idx`. * **Integers**: `i<size>:<abi>:<pref>`. * **Floats**: `f<size>:<abi>:<pref>`. * **Vectors**: `v<size>:<abi>:<pref>`. * **Native widths**: `n<size1>:<size2>:...`. * **Stack alignment**: `S<size>`. * **Address space**: `A<addr_space>`. * **Mangling**: `m: