Mojo struct
WGMMADescriptor
@register_passable(trivial)
struct WGMMADescriptor[dtype: DType]
Descriptor for shared memory operands used in warp group matrix multiply operations.
This struct encodes the shared memory layout and access pattern of a warp group matrix multiply (WGMMA) operand in a single 64-bit value. The descriptor contains the following bit fields:
- Start address (14 bits): Base address of the operand in shared memory.
- Leading byte offset (14 bits): Leading dimension byte offset.
- Stride byte offset (14 bits): Stride dimension byte offset.
- Base offset (3 bits): Matrix base offset, which applies only to swizzled layouts.
- Swizzle mode (2 bits): Shared memory swizzling pattern.
The bit layout is:
+----------+-------+------------+-------+------------+-------+-------+----------+-------+
|   0-13   | 14-15 |   16-29    | 30-31 |   32-45    | 46-48 | 49-51 |  52-61   | 62-63 |
+----------+-------+------------+-------+------------+-------+-------+----------+-------+
| 14 bits  | 2 bits|  14 bits   | 2 bits|  14 bits   | 3 bits| 3 bits|  10 bits | 2 bits|
+----------+-------+------------+-------+------------+-------+-------+----------+-------+
| BaseAddr |   0   | LeadingDim |   0   |   Stride   |   0   | Offset|    0     |Swizzle|
+----------+-------+------------+-------+------------+-------+-------+----------+-------+
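The following Python sketch models the bit layout as plain integer arithmetic so the field positions can be checked outside a GPU kernel. It does not use the Mojo API: `FIELDS`, `pack_descriptor`, and `unpack_descriptor` are hypothetical names, and the functions take field values that are already encoded (the sketch does not model how `create` derives them from byte addresses).

```python
# (low bit, width) for each field, read off the bit layout table.
FIELDS = {
    "start_address": (0, 14),
    "leading_byte_offset": (16, 14),
    "stride_byte_offset": (32, 14),
    "base_offset": (49, 3),
    "swizzle_mode": (62, 2),
}


def pack_descriptor(start_address, leading_byte_offset, stride_byte_offset,
                    base_offset=0, swizzle_mode=0):
    """Place already-encoded field values at their bit positions.

    Reserved bits (14-15, 30-31, 46-48, 52-61) stay zero.
    """
    values = {
        "start_address": start_address,
        "leading_byte_offset": leading_byte_offset,
        "stride_byte_offset": stride_byte_offset,
        "base_offset": base_offset,
        "swizzle_mode": swizzle_mode,
    }
    desc = 0
    for name, (low, width) in FIELDS.items():
        value = values[name]
        # Reject values that would spill into a neighbouring field.
        if not 0 <= value < (1 << width):
            raise ValueError(f"{name}={value} does not fit in {width} bits")
        desc |= value << low
    return desc


def unpack_descriptor(desc):
    """Inverse of pack_descriptor for a non-negative 64-bit value."""
    return {name: (desc >> low) & ((1 << width) - 1)
            for name, (low, width) in FIELDS.items()}
```

Packing every field at its maximum value sets exactly 47 of the 64 bits and leaves the 17 reserved bits at zero, which is a quick way to confirm that the widths in the table add up to 64.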
Fields
- desc (SIMD[int64, 1]): The packed 64-bit descriptor value.
Implemented traits
AnyType, UnknownDestructibility
Methods
__init__
@implicit
__init__(val: SIMD[int64, 1]) -> Self
Initialize descriptor with raw 64-bit value.
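Because `desc` is a signed `int64`, any raw value with bit 63 set (a swizzle field of 2 or 3) reads as a negative number. The Python sketch below, with hypothetical helper names, shows how such a raw value can be reinterpreted as an unsigned bit pattern and split into the documented fields.

```python
MASK64 = (1 << 64) - 1


def decode_descriptor(raw):
    """Split a raw descriptor value into its bit fields.

    `raw` may be negative because the Mojo field is a signed int64;
    masking to 64 bits recovers the unsigned bit pattern first.
    """
    bits = raw & MASK64
    return {
        "start_address": bits & 0x3FFF,
        "leading_byte_offset": (bits >> 16) & 0x3FFF,
        "stride_byte_offset": (bits >> 32) & 0x3FFF,
        "base_offset": (bits >> 49) & 0x7,
        "swizzle_mode": (bits >> 62) & 0x3,
    }


def to_signed_int64(bits):
    """Reinterpret an unsigned 64-bit pattern as a signed int64 value."""
    return bits - (1 << 64) if bits >= (1 << 63) else bits


# Swizzle field 2 sets bit 63, so the signed value is negative.
raw = (8 << 32) + (1 << 16) + 0x40 - (1 << 63)
fields = decode_descriptor(raw)
```

Here `fields` holds a start address of 0x40, a leading byte offset of 1, a stride byte offset of 8, a base offset of 0, and a swizzle mode of 2, even though `raw` itself is negative.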
__add__
__add__(self, offset: Int) -> Self
Add offset to descriptor's base address.
Args:
- offset (Int): Byte offset to add to base address.
Returns:
New descriptor with updated base address.
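A minimal Python model of adding a byte offset, assuming (as in the PTX ISA matrix-descriptor encoding) that the start address field stores the byte address shifted right by 4. Under that assumption a byte offset contributes `offset >> 4` to the raw value. `add_offset` is a hypothetical name, and its alignment and overflow checks belong to the sketch, not necessarily to the library.

```python
def add_offset(desc, offset_bytes):
    """Model of `descriptor + offset` on a raw descriptor value.

    Assumes the start address field (bits 0-13) holds byte_address >> 4.
    """
    if offset_bytes < 0 or offset_bytes % 16:
        raise ValueError("offset must be a non-negative multiple of 16 bytes")
    step = offset_bytes >> 4
    # A carry out of bit 13 would corrupt the reserved bits above the field.
    if (desc & 0x3FFF) + step > 0x3FFF:
        raise ValueError("offset overflows the 14-bit start address field")
    return desc + step


# Start address 1024 bytes (encoded 64), leading offset 1, stride offset 64.
desc = (1024 >> 4) | (1 << 16) | (64 << 32)
moved = add_offset(desc, 256)  # start field becomes (1024 + 256) >> 4 == 80
```

Adding directly to the raw value is cheaper than unpacking and repacking, and it leaves the other fields untouched as long as the start address field does not overflow.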
__iadd__
__iadd__(mut self, offset: Int)
Add offset to descriptor's base address in-place.
Args:
- offset (Int): Byte offset to add to base address.
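The in-place form is convenient when one descriptor is stepped across consecutive slices of a shared memory tile, with one step per WGMMA instruction. The Python class below is a hypothetical stand-in for the struct that keeps only the raw value. It assumes the same `>> 4` start address encoding used by the PTX ISA, and the 64-byte step between slices is an arbitrary illustrative value.

```python
class DescriptorModel:
    """Hypothetical stand-in for WGMMADescriptor holding only the raw value."""

    def __init__(self, raw):
        self.desc = raw

    def __iadd__(self, offset_bytes):
        # Assumed encoding: the start address field stores byte_address >> 4.
        self.desc += offset_bytes >> 4
        return self

    @property
    def start_address_bytes(self):
        # Decode bits 0-13 back to a byte address.
        return (self.desc & 0x3FFF) << 4


d = DescriptorModel((0x1000 >> 4) | (1 << 16) | (64 << 32))
starts = []
for k in range(4):
    starts.append(d.start_address_bytes)  # a WGMMA for slice k would read d here
    d += 64  # hypothetical byte distance between consecutive slices
```

After the loop, `starts` holds the byte addresses 0x1000, 0x1040, 0x1080, and 0x10C0, and the leading and stride fields are unchanged.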
create
static create[stride_byte_offset: Int, leading_byte_offset: Int, swizzle_mode: TensorMapSwizzle = TensorMapSwizzle.SWIZZLE_NONE](smem_ptr: UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]) -> Self
Create a descriptor for a shared memory operand.
Parameters:
- stride_byte_offset (Int): Stride dimension byte offset.
- leading_byte_offset (Int): Leading dimension byte offset.
- swizzle_mode (TensorMapSwizzle): Memory access (swizzle) pattern. Defaults to no swizzling.
Args:
- smem_ptr (UnsafePointer[SIMD[dtype, 1], address_space=AddressSpace(3)]): Pointer to the shared memory operand.
Returns:
Initialized descriptor for the shared memory operand.
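The sketch below models `create` in Python under stated assumptions. The shared memory address and both byte offsets are encoded as in the PTX ISA matrix-descriptor format, `(x & 0x3FFFF) >> 4`. The base offset is left at 0. The swizzle mode is translated to the PTX 2-bit code (0 none, 1 128-byte, 2 64-byte, 3 32-byte), which differs from the order of the `TensorMapSwizzle` values. Whether the Mojo implementation performs exactly these steps is an assumption. `create_descriptor`, `encode`, and the string swizzle labels are hypothetical stand-ins, not the real API.

```python
# Hypothetical labels mapped to the PTX 2-bit swizzle codes for wgmma.
PTX_SWIZZLE_CODE = {"none": 0, "128B": 1, "64B": 2, "32B": 3}


def encode(x):
    """PTX matrix-descriptor-encode: keep the low 18 bits, drop the low 4."""
    return (x & 0x3FFFF) >> 4


def create_descriptor(smem_addr, stride_byte_offset, leading_byte_offset,
                      swizzle="none"):
    """Model of WGMMADescriptor.create on plain integers (base offset = 0)."""
    for name, value in (("smem_addr", smem_addr),
                        ("stride_byte_offset", stride_byte_offset),
                        ("leading_byte_offset", leading_byte_offset)):
        # encode() drops the low 4 bits, so unaligned values would be lost.
        if value % 16:
            raise ValueError(f"{name}={value} is not 16-byte aligned")
    return (encode(smem_addr)
            | encode(leading_byte_offset) << 16
            | encode(stride_byte_offset) << 32
            | PTX_SWIZZLE_CODE[swizzle] << 62)


desc = create_descriptor(0x400, stride_byte_offset=1024,
                         leading_byte_offset=16, swizzle="128B")
```

Rejecting unaligned inputs is a choice made by the sketch: without the check, the encoding would silently drop the low 4 bits of each value.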