
Mojo struct

WGMMADescriptor

@register_passable(trivial) struct WGMMADescriptor[dtype: DType]

Descriptor for shared memory operands used in warp group matrix multiply operations.

This struct represents a descriptor that encodes information about shared memory layout and access patterns for warp group matrix multiply operations. The descriptor contains the following bit fields:

  • Start address (14 bits): Base address in shared memory.
  • Leading byte offset (14 bits): Leading dimension stride in bytes.
  • Stride byte offset (14 bits): Stride dimension offset in bytes.
  • Base offset (3 bits): Matrix base offset, stored in bits 49-51.
  • Swizzle mode (2 bits): Swizzle pattern applied to shared memory accesses, stored in bits 62-63.

The bit layout is (fields marked 0 are reserved and set to zero):

+----------+-------+------------+-------+------------+-------+--------+----------+---------+
| 0-13     | 14-15 | 16-29      | 30-31 | 32-45      | 46-48 | 49-51  | 52-61    | 62-63   |
+----------+-------+------------+-------+------------+-------+--------+----------+---------+
| 14 bits  | 2 bits| 14 bits    | 2 bits| 14 bits    | 3 bits| 3 bits | 10 bits  | 2 bits  |
+----------+-------+------------+-------+------------+-------+--------+----------+---------+
| BaseAddr | 0     | LeadingDim | 0     | Stride     | 0     | Offset | 0        | Swizzle |
+----------+-------+------------+-------+------------+-------+--------+----------+---------+

See: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
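
To make the layout above concrete, here is a minimal sketch in Python (not Mojo) that splits a raw 64-bit value into the five named fields. The `FIELDS` table and `decode_descriptor` are hypothetical names used only for illustration; they are not part of the Mojo API, and the shifts and widths are taken directly from the bit layout described above.

```python
# Hypothetical model of the 64-bit descriptor layout; not the Mojo API.
# Each entry maps a field name to (lowest bit position, width in bits).
FIELDS = {
    "start_address": (0, 14),
    "leading_byte_offset": (16, 14),
    "stride_byte_offset": (32, 14),
    "base_offset": (49, 3),
    "swizzle_mode": (62, 2),
}


def decode_descriptor(desc: int) -> dict:
    """Split a raw 64-bit descriptor value into its named bit fields."""
    return {
        name: (desc >> shift) & ((1 << width) - 1)
        for name, (shift, width) in FIELDS.items()
    }


# Build a sample raw value by hand and decode it again.
raw = (1 << 62) | (64 << 32) | (8 << 16) | 0x40
fields = decode_descriptor(raw)
```

The reserved bits (14-15, 30-31, 46-48, 52-61) are simply ignored by this decoder, which matches the layout's requirement that they stay zero.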

Fields

  • desc (SIMD[int64, 1]): The raw 64-bit descriptor value that holds the encoded bit fields.

Implemented traits

AnyType, UnknownDestructibility

Methods

__init__

@implicit __init__(val: SIMD[int64, 1]) -> Self

Initialize descriptor with raw 64-bit value.

__add__

__add__(self, offset: Int) -> Self

Add offset to descriptor's base address.

Args:

  • offset (Int): Byte offset to add to base address.

Returns:

New descriptor with updated base address.

__iadd__

__iadd__(mut self, offset: Int)

Add offset to descriptor's base address in-place.

Args:

  • offset (Int): Byte offset to add to base address.
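
The effect of `__add__` and `__iadd__` on the start address field can be modeled with plain integers. The sketch below is in Python rather than Mojo and rests on two assumptions that this page does not state: that the byte offset is encoded the way the PTX documentation encodes descriptor addresses, `(x & 0x3FFFF) >> 4` (16-byte units), and that the result stays within the 14-bit field. `encode_bytes` and `add_offset` are hypothetical helper names.

```python
ADDR_MASK = (1 << 14) - 1  # start address occupies bits 0-13


def encode_bytes(x: int) -> int:
    # Assumption: PTX-style encoding that drops the low 4 bits,
    # so the stored value counts 16-byte units.
    return (x & 0x3FFFF) >> 4


def add_offset(desc: int, offset_bytes: int) -> int:
    """Return a new descriptor value with the start address advanced."""
    start = (desc & ADDR_MASK) + encode_bytes(offset_bytes)
    # Keep every other field untouched; wrap within the 14-bit field.
    return (desc & ~ADDR_MASK) | (start & ADDR_MASK)


desc = (1 << 62) | 0x40          # swizzle field = 1, start address = 0x40
moved = add_offset(desc, 256)    # 256 bytes = 16 units of 16 bytes
```

In this model `add_offset` returns a new value, which mirrors `__add__`; `__iadd__` would assign the same result back to the descriptor.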

create

static create[stride_byte_offset: Int, leading_byte_offset: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle(__init__[__mlir_type.!kgen.int_literal](0))](smem_ptr: UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]) -> Self

Create a descriptor for shared memory operand.

Parameters:

  • stride_byte_offset (Int): Stride dimension offset in bytes.
  • leading_byte_offset (Int): Leading dimension stride in bytes.
  • swizzle_mode (TensorMapSwizzle): Swizzle mode for the shared memory access pattern. Defaults to the TensorMapSwizzle value constructed from 0.

Args:

  • smem_ptr (UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]): Pointer to shared memory operand.

Returns:

Initialized descriptor for the shared memory operand.
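
As a rough illustration of how the parameters of `create` map onto the bit fields, the following Python sketch (not Mojo) packs an address, the two byte offsets, and a swizzle value into one 64-bit integer. `create_descriptor` and `encode_bytes` are hypothetical names. Whether the real `create` applies the PTX-style `(x & 0x3FFFF) >> 4` encoding to each input itself is an assumption here, and the base offset field is left at zero.

```python
def encode_bytes(x: int) -> int:
    # Assumption: PTX-style encoding, value stored in 16-byte units.
    return (x & 0x3FFFF) >> 4


def create_descriptor(
    smem_addr: int,
    stride_byte_offset: int,
    leading_byte_offset: int,
    swizzle_mode: int = 0,
) -> int:
    """Pack the operands into the 64-bit layout (base offset left at 0)."""
    desc = encode_bytes(smem_addr)                    # bits 0-13
    desc |= encode_bytes(leading_byte_offset) << 16   # bits 16-29
    desc |= encode_bytes(stride_byte_offset) << 32    # bits 32-45
    desc |= (swizzle_mode & 0x3) << 62                # bits 62-63
    return desc


# Sample operand: address 0x400, stride 1024 bytes, leading 128 bytes.
packed = create_descriptor(0x400, 1024, 128, swizzle_mode=1)
```

The argument order follows the parameter order of `create` (stride first, then leading), which is easy to transpose by accident when calling it.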