
Mojo struct

RaggedTMA3DTile

struct RaggedTMA3DTile[dtype: DType, swizzle_mode: TensorMapSwizzle, BM: Int, BN: Int]

Creates a TMA descriptor for loading from and storing to 3D arrays whose leading dimension is ragged. Each copy transfers a 2D tile, indexing into the middle dimension. When using these loads, at least BM * stride of space must be allocated in front of the global memory (gmem) pointer; otherwise, CUDA_ERROR_ILLEGAL_ADDRESS may result.
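The front-padding requirement can be budgeted on the host before the pointer is handed to `create`. The Python sketch below is not the Mojo API. It assumes that `stride` means the number of elements between consecutive rows of the ragged dimension (`middle_dim * depth` for a contiguous array) and that the padding is counted in elements. `padded_allocation` is a hypothetical helper.

```python
def padded_allocation(rows, middle_dim, depth, bm, elem_bytes):
    """Return (total_bytes, pointer_offset_bytes) for a front-padded buffer.

    Hypothetical helper. Assumes stride = middle_dim * depth elements and that
    the required BM * stride padding is measured in elements.
    """
    stride = middle_dim * depth       # elements between consecutive ragged rows
    front_pad = bm * stride           # elements reserved in front of the data
    data = rows * stride              # elements of actual tensor data
    total_bytes = (front_pad + data) * elem_bytes
    pointer_offset_bytes = front_pad * elem_bytes  # where the ptr passed to create starts
    return total_bytes, pointer_offset_bytes


total, offset = padded_allocation(rows=1000, middle_dim=8, depth=128,
                                  bm=128, elem_bytes=2)
```

Consider rows = 1000, middle_dim = 8, depth = 128, BM = 128 and a 2-byte dtype such as bfloat16. Under these assumptions, the sketch reserves 262144 bytes of front padding out of 2310144 bytes in total.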

Parameters

  • dtype (DType): The data type of the tensor.
  • swizzle_mode (TensorMapSwizzle): The swizzling mode to use for memory access.
  • BM (Int): The number of rows of the corresponding 2D shared memory tile.
  • BN (Int): The number of columns of the corresponding 2D shared memory tile.

Fields

  • descriptor (TMADescriptor): The TMA descriptor used to load and store the ragged tensor.

Implemented traits

AnyType, Copyable, DevicePassable, ImplicitlyCopyable, ImplicitlyDestructible, Movable

comptime members

__copyinit__is_trivial

comptime __copyinit__is_trivial = False

__del__is_trivial

comptime __del__is_trivial = True

__moveinit__is_trivial

comptime __moveinit__is_trivial = True

device_type

comptime device_type = RaggedTMA3DTile[dtype, swizzle_mode, BM, BN]

The device-side type representation.

layout

comptime layout = tile_layout_k_major[dtype, BM, BN, swizzle_mode]()

The unswizzled shared memory (smem) layout that this TMA operation copies to and from.

swizzle_granularity

comptime swizzle_granularity = (swizzle_mode.bytes() // size_of[dtype]())

The number of columns that must be copied at a time due to the swizzle size.
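The `swizzle_granularity` formula can be evaluated by hand. This Python sketch assumes the 32B, 64B and 128B swizzle widths used by CUDA TMA. It uses illustrative string names rather than the Mojo `TensorMapSwizzle` and `DType` values. The hypothetical `column_chunks` helper assumes that BN is a multiple of the granularity.

```python
# Illustrative names only; these are not the Mojo TensorMapSwizzle/DType values.
SWIZZLE_BYTES = {"SWIZZLE_32B": 32, "SWIZZLE_64B": 64, "SWIZZLE_128B": 128}
DTYPE_BYTES = {"float8_e4m3fn": 1, "bfloat16": 2, "float32": 4}


def swizzle_granularity(swizzle_mode, dtype):
    """Columns per copy: swizzle width in bytes // element size in bytes."""
    return SWIZZLE_BYTES[swizzle_mode] // DTYPE_BYTES[dtype]


def column_chunks(bn, swizzle_mode, dtype):
    """Number of granularity-wide copies per BN-wide row (assumes divisibility)."""
    g = swizzle_granularity(swizzle_mode, dtype)
    if bn % g != 0:
        raise ValueError("this sketch assumes BN is a multiple of the granularity")
    return bn // g


print(swizzle_granularity("SWIZZLE_128B", "bfloat16"))  # 64
print(column_chunks(128, "SWIZZLE_128B", "bfloat16"))   # 2
```

With SWIZZLE_128B and bfloat16, the granularity is 128 // 2 = 64 columns. Under the divisibility assumption, a 128-column tile would therefore be copied as two 64-column chunks.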

Methods

__init__

@implicit __init__(out self, descriptor: TMADescriptor)

Initializes a new RaggedTMA3DTile with the provided TMA descriptor.

Args:

  • descriptor (TMADescriptor): The TMA descriptor that defines the memory access pattern.

__copyinit__

__copyinit__(out self, other: Self)

Copy initializes this RaggedTMA3DTile from another instance.

Args:

  • other (Self): The other RaggedTMA3DTile instance to copy from.

get_type_name

static get_type_name() -> String

Returns a string representation of the RaggedTMA3DTile type.

Returns:

String: A string containing the type name with all template parameters.

create

static create[*, depth: Int = BN](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], origin], *, rows: Int, middle_dim: Int) -> Self

Create a RaggedTMA3DTile.

Parameters:

  • depth (Int): The size of the innermost (contiguous) dimension.

Args:

  • ctx (DeviceContext): The device context used to create the TMA descriptors.
  • ptr (UnsafePointer): The global memory pointer.
  • rows (Int): The size of the ragged dimension.
  • middle_dim (Int): The size of the middle dimension.

Returns:

Self: A RaggedTMA3DTile describing the global memory at ptr.

Raises:

If TMA descriptor creation fails.
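A ragged batch is often several variable-length sequences packed along `rows`. The Python sketch below models that layout, but it relies on assumptions this page does not state:

- the array is contiguous, with `depth` as the innermost dimension;
- sequences are packed back to back along `rows`;
- a sequence's `ragged_idx` is the prefix sum of the lengths before it.

`ragged_offsets` and `element_offset` are hypothetical helpers.

```python
from itertools import accumulate


def ragged_offsets(seq_lens):
    """Starting row (ragged_idx) of each packed sequence."""
    return list(accumulate(seq_lens, initial=0))[:-1]


def element_offset(row, middle_idx, col, *, middle_dim, depth):
    """Flat element offset of [row, middle_idx, col] in a contiguous array."""
    return (row * middle_dim + middle_idx) * depth + col


seq_lens = [5, 3, 7]
rows = sum(seq_lens)                 # value that would be passed as `rows`: 15
starts = ragged_offsets(seq_lens)    # [0, 5, 8]
# First row of sequence 1, middle (head) index 2, column 10:
off = element_offset(starts[1], 2, 10, middle_dim=8, depth=128)  # 5386
```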

get_device_type_name

static get_device_type_name() -> String

Gets the name of device_type, for use in error messages when passing arguments to kernels.

Returns:

String: This type's name.

async_copy_to

async_copy_to[cta_group: Int = 1](self, dst: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED], ref [3] mem_barrier: SharedMemBarrier, *, ragged_idx: UInt32, dynamic_dim: UInt32, middle_idx: UInt32)

Copies a tile from the global memory described by this RaggedTMA3DTile to the shared memory destination.

Parameters:

  • cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion.

Args:

  • dst (UnsafePointer): The shared memory pointer to copy into.
  • mem_barrier (SharedMemBarrier): The memory barrier used to track and synchronize the asynchronous transfer.
  • ragged_idx (UInt32): Index into the ragged dimension.
  • dynamic_dim (UInt32): Number of rows to copy.
  • middle_idx (UInt32): Index into the middle (generally head) dimension.
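The rows read by `async_copy_to` can be modeled in Python, for example to check kernel results on the CPU. This is a hypothetical reference model, not the TMA implementation. It assumes:

- the array is contiguous, with `depth` as the innermost dimension;
- `depth == BN`, which is the default for `create`;
- `dynamic_dim <= BM`.

The model leaves tile rows past `dynamic_dim` as `None` because this page does not say what they contain.

```python
def model_async_copy_to(gmem, *, middle_dim, depth, bm,
                        ragged_idx, dynamic_dim, middle_idx):
    """Return the BM-row tile async_copy_to would fill (depth == BN assumed)."""
    if not 0 <= dynamic_dim <= bm:
        raise ValueError("dynamic_dim must be in [0, BM]")
    tile = [None] * bm                       # rows past dynamic_dim: unspecified
    for r in range(dynamic_dim):
        start = ((ragged_idx + r) * middle_dim + middle_idx) * depth
        tile[r] = gmem[start:start + depth]  # one depth-wide row of the tile
    return tile


gmem = list(range(6 * 2 * 4))  # rows = 6, middle_dim = 2, depth = 4
tile = model_async_copy_to(gmem, middle_dim=2, depth=4, bm=4,
                           ragged_idx=3, dynamic_dim=2, middle_idx=1)
```

Here the model reads rows 3 and 4 at middle index 1, giving the elements 28 to 31 and 36 to 39. The last two tile rows stay `None`.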

async_copy_from

async_copy_from(self, src: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED], *, ragged_idx: UInt32, dynamic_dim: UInt32, middle_idx: UInt32)

Copies a tile from the shared memory source to the global memory described by this RaggedTMA3DTile.

Args:

  • src (UnsafePointer): The shared memory pointer to copy from.
  • ragged_idx (UInt32): Index into the ragged dimension.
  • dynamic_dim (UInt32): Number of rows to copy.
  • middle_idx (UInt32): Index into the middle (generally head) dimension.
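Because `dynamic_dim` bounds the number of rows copied, storing a partially filled tile should leave the next sequence in a packed batch untouched. The hypothetical Python model below shows this under the same contiguous-layout assumptions as a load model. It is a CPU reference, not the TMA implementation.

```python
def model_async_copy_from(gmem, tile, *, middle_dim, depth,
                          ragged_idx, dynamic_dim, middle_idx):
    """Write the first dynamic_dim rows of tile into gmem in place."""
    if not 0 <= dynamic_dim <= len(tile):
        raise ValueError("dynamic_dim must be in [0, BM]")
    for r in range(dynamic_dim):
        start = ((ragged_idx + r) * middle_dim + middle_idx) * depth
        gmem[start:start + depth] = tile[r]
    return gmem


# Two packed sequences of lengths 2 and 3; middle_dim = 1, depth = 2.
gmem = [0] * (5 * 1 * 2)
tile = [[9, 9] for _ in range(4)]   # BM = 4 rows, all filled
model_async_copy_from(gmem, tile, middle_dim=1, depth=2,
                      ragged_idx=0, dynamic_dim=2, middle_idx=0)
```

The tile holds four rows, but only the first two (sequence 0) are written. The rows of sequence 1 in `gmem` keep their zeros.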

prefetch_descriptor

prefetch_descriptor(self)

Prefetches the TMA descriptor into cache.
