Mojo struct
RaggedTMA3DTile
struct RaggedTMA3DTile[dtype: DType, swizzle_mode: TensorMapSwizzle, BM: Int, BN: Int]
Creates a TMA descriptor for loading from and storing to 3D arrays whose leading dimension is ragged. It transfers 2D tiles, indexing into the middle dimension. When using it for loads, at least BM * stride space must be allocated in front of the gmem pointer; otherwise, CUDA_ERROR_ILLEGAL_ADDRESS may result.
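The front-padding requirement can be sized on the host before calling create. The Python sketch below is a host-side model, not part of the Mojo API. It assumes that "stride" means the row stride of the ragged dimension, that is, middle_dim * depth elements for a densely packed [rows, middle_dim, depth] array. The helper names ragged_front_padding and allocation_size are hypothetical.

```python
# Hypothetical host-side helpers, not part of the Mojo API.
# Assumption: "stride" is the row stride of the ragged dimension, i.e.
# middle_dim * depth elements for a dense [rows, middle_dim, depth] array.


def ragged_front_padding(bm: int, middle_dim: int, depth: int, elem_bytes: int):
    """Return (elements, bytes) to allocate in front of the gmem pointer."""
    stride_elems = middle_dim * depth  # elements between consecutive ragged rows
    pad_elems = bm * stride_elems      # BM * stride
    return pad_elems, pad_elems * elem_bytes


def allocation_size(rows: int, bm: int, middle_dim: int, depth: int, elem_bytes: int):
    """Return (total bytes to allocate, byte offset of the pointer passed to create)."""
    _, pad_bytes = ragged_front_padding(bm, middle_dim, depth, elem_bytes)
    data_bytes = rows * middle_dim * depth * elem_bytes
    return pad_bytes + data_bytes, pad_bytes


# Example: BM = 128, 8 heads, depth 128, bfloat16 (2 bytes per element).
print(ragged_front_padding(128, 8, 128, 2))  # (131072, 262144)
```

Under these assumptions, the pointer handed to create would sit pad_bytes past the start of the allocation, so the region the descriptor may touch in front of it is still owned memory.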
Parameters
- dtype (DType): The data type of the tensor.
- swizzle_mode (TensorMapSwizzle): The swizzling mode to use for memory access.
- BM (Int): The number of rows of the corresponding 2D shared memory tile.
- BN (Int): The number of columns of the corresponding 2D shared memory tile.
Fields
- descriptor (TMADescriptor): The TMA descriptor used to load or store the ragged tensor.
Implemented traits
AnyType,
Copyable,
DevicePassable,
ImplicitlyCopyable,
ImplicitlyDestructible,
Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = False
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
device_type
comptime device_type = RaggedTMA3DTile[dtype, swizzle_mode, BM, BN]
The device-side type representation.
layout
comptime layout = tile_layout_k_major[dtype, BM, BN, swizzle_mode]()
The unswizzled shared memory layout that this TMA operation copies to and from.
swizzle_granularity
comptime swizzle_granularity = (swizzle_mode.bytes() // size_of[dtype]())
The number of columns that must be copied at a time due to the swizzle size.
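As a worked example of this formula, the Python sketch below evaluates swizzle_granularity for the swizzled modes (32, 64, and 128 bytes) and a few element sizes. The copies_per_row helper is hypothetical and assumes, for illustration only, that a BN-column tile row is moved in swizzle_granularity-wide pieces.

```python
# Byte widths of the swizzled TensorMapSwizzle modes (SWIZZLE_NONE omitted).
SWIZZLE_BYTES = {"SWIZZLE_32B": 32, "SWIZZLE_64B": 64, "SWIZZLE_128B": 128}


def swizzle_granularity(swizzle_bytes: int, elem_bytes: int) -> int:
    # Mirrors: swizzle_mode.bytes() // size_of[dtype]()
    return swizzle_bytes // elem_bytes


def copies_per_row(bn: int, granularity: int) -> int:
    # Hypothetical: number of granularity-wide chunks covering BN columns
    # (ceiling division, in case BN is not a multiple of the granularity).
    return -(-bn // granularity)


for mode, nbytes in SWIZZLE_BYTES.items():
    for name, elem_bytes in [("fp8", 1), ("bfloat16", 2), ("float32", 4)]:
        print(mode, name, swizzle_granularity(nbytes, elem_bytes))
```

For example, SWIZZLE_128B with bfloat16 gives 128 // 2 = 64 columns, so a BN = 128 tile row would span two swizzle-width chunks.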
Methods
__init__
@implicit
__init__(out self, descriptor: TMADescriptor)
Initializes a new RaggedTMA3DTile with the provided TMA descriptor.
Args:
- descriptor (TMADescriptor): The TMA descriptor that defines the memory access pattern.
__copyinit__
__copyinit__(out self, other: Self)
Copy initializes this RaggedTMA3DTile from another instance.
Args:
- other (Self): The other RaggedTMA3DTile instance to copy from.
get_type_name
static get_type_name() -> String
Returns a string representation of the RaggedTMA3DTile type.
Returns:
String: A string containing the type name with all template parameters.
create
static create[*, depth: Int = BN](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], origin], *, rows: Int, middle_dim: Int) -> Self
Create a RaggedTMA3DTile.
Parameters:
- depth (Int): The size of the innermost, contiguous dimension. Defaults to BN.
Args:
- ctx (DeviceContext): The device context used to create the TMA descriptors.
- ptr (UnsafePointer): The global memory pointer.
- rows (Int): The size of the ragged dimension.
- middle_dim (Int): The size of the middle dimension.
Returns:
Self: A RaggedTMA3DTile corresponding to the gmem.
Raises:
If TMA descriptor creation fails.
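The global memory that create describes can be pictured as a [rows, middle_dim, depth] array. The Python sketch below computes the element offset of one entry from ptr, assuming a densely packed row-major layout with depth as the contiguous dimension. The element_offset helper is hypothetical, not part of the API.

```python
def element_offset(ragged_idx: int, middle_idx: int, col: int, *,
                   rows: int, middle_dim: int, depth: int) -> int:
    """Flat element offset of (ragged_idx, middle_idx, col) from ptr.

    Assumes a dense row-major [rows, middle_dim, depth] array.
    """
    if not (0 <= ragged_idx < rows and 0 <= middle_idx < middle_dim and 0 <= col < depth):
        raise IndexError("index out of range for [rows, middle_dim, depth]")
    return (ragged_idx * middle_dim + middle_idx) * depth + col


# 8 heads of depth 128: row 2, head 3, column 5.
print(element_offset(2, 3, 5, rows=10, middle_dim=8, depth=128))  # 2437
```

In this model, consecutive ragged rows are middle_dim * depth elements apart, which is the stride assumed for the front-padding requirement.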
get_device_type_name
static get_device_type_name() -> String
Gets device_type's name, for use in error messages when handing arguments to kernels.
Returns:
String: This type's name.
async_copy_to
async_copy_to[cta_group: Int = 1](self, dst: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED], ref [3] mem_barrier: SharedMemBarrier, *, ragged_idx: UInt32, dynamic_dim: UInt32, middle_idx: UInt32)
Copy from the RaggedTMA3DTile source to the smem destination.
Parameters:
- cta_group (Int): If the TMA is issued with cta_group == 2, only the leader CTA needs to be notified upon completion.
Args:
- dst (UnsafePointer): The destination shared memory pointer to which we copy memory.
- mem_barrier (SharedMemBarrier): The memory barrier used to track and synchronize the asynchronous transfer.
- ragged_idx (UInt32): Index into the ragged dimension.
- dynamic_dim (UInt32): Number of rows to copy.
- middle_idx (UInt32): Index into the middle (generally head) dimension.
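One way to drive async_copy_to over a ragged batch is to walk each sequence in BM-row tiles. The Python sketch below computes, per tile, a ragged_idx (the first row of the tile within the packed ragged dimension) and a dynamic_dim (the number of valid rows, at most BM). It assumes sequences are stored back to back along the ragged dimension; ragged_tiles is a hypothetical helper, and a kernel would presumably issue one copy per tile and per middle_idx.

```python
def ragged_tiles(seq_lens: list[int], bm: int) -> list[tuple[int, int, int]]:
    """Return (sequence, ragged_idx, dynamic_dim) for each BM-row tile.

    Assumes sequences are stored back to back along the ragged dimension.
    """
    tiles = []
    start = 0  # first row of the current sequence in the packed dimension
    for seq, n in enumerate(seq_lens):
        for off in range(0, n, bm):
            tiles.append((seq, start + off, min(bm, n - off)))
        start += n
    return tiles


print(ragged_tiles([5, 130, 64], bm=64))
# [(0, 0, 5), (1, 5, 64), (1, 69, 64), (1, 133, 2), (2, 135, 64)]
```

Only the last tile of each sequence has dynamic_dim < BM, which is where a ragged copy differs from a regular fixed-size TMA tile.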
async_copy_from
async_copy_from(self, src: UnsafePointer[Scalar[dtype], origin, address_space=AddressSpace.SHARED], *, ragged_idx: UInt32, dynamic_dim: UInt32, middle_idx: UInt32)
Copy from the smem source to the RaggedTMA3DTile destination.
Args:
- src (UnsafePointer): The source shared memory pointer from which we copy memory.
- ragged_idx (UInt32): Index into the ragged dimension.
- dynamic_dim (UInt32): Number of rows to copy.
- middle_idx (UInt32): Index into the middle (generally head) dimension.
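Because dynamic_dim gives the number of rows to copy, a store with dynamic_dim < BM is expected to leave the following rows of global memory untouched. The Python reference model below illustrates that contract on plain lists. It assumes depth == BN (the default) and a dense [rows, middle_dim, depth] layout; ragged_store is a hypothetical stand-in for the device-side transfer, not its implementation.

```python
def ragged_store(gmem: list, tile: list, *, ragged_idx: int, dynamic_dim: int,
                 middle_idx: int, middle_dim: int, depth: int) -> None:
    """Host-side model of a ragged smem -> gmem tile store.

    gmem: flat list for a dense [rows, middle_dim, depth] array.
    tile: BM rows of `depth` values (assumes depth == BN).
    Only the first `dynamic_dim` tile rows are written.
    """
    if dynamic_dim > len(tile):
        raise ValueError("dynamic_dim cannot exceed BM")
    for r in range(dynamic_dim):
        base = ((ragged_idx + r) * middle_dim + middle_idx) * depth
        gmem[base:base + depth] = tile[r][:depth]


# rows=4, middle_dim=2, depth=3; a BM=3 tile whose last row is padding.
gmem = [0] * (4 * 2 * 3)
tile = [[1, 1, 1], [2, 2, 2], [9, 9, 9]]
ragged_store(gmem, tile, ragged_idx=1, dynamic_dim=2, middle_idx=1, middle_dim=2, depth=3)
print(gmem)
```

The padding row of 9s never reaches gmem, and the other head (middle_idx 0) is left as it was.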
prefetch_descriptor
prefetch_descriptor(self)
Prefetches the TMA descriptor into cache.