Mojo struct
ScalesTileLoader
@register_passable(trivial)
struct ScalesTileLoader[tma_origin: ImmutOrigin, dtype: DType, gmem_layout: Layout, desc_layout: Layout, /, *, cta_group: Int]
TMA-based scales tile loader for blockwise FP8.
Unlike TileLoaderTMA, this loader:
- Uses async_copy (no multicast) since scales aren't distributed across CTAs
- Uses (row_coord, k_coord) coordinate order matching scales tensor layout
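For context, blockwise FP8 attaches one scale factor to each block of quantized elements, so the scales tensor is much smaller than the data tensor it describes. The following is a minimal host-side sketch in Python of that idea, not Mojo and not this library's API. The function name `dequantize_blockwise`, the block width `BLOCK_K`, and the `[row][k_block]` indexing of the scales are all assumptions chosen for illustration.

```python
# Conceptual host-side model only; this is not the kernel's implementation.
BLOCK_K = 4  # assumed block width along K, kept small for readability


def dequantize_blockwise(a_q, a_scales, block_k=BLOCK_K):
    """Multiply each quantized element by the scale of its K-block.

    a_q:      M x K nested list (plain floats stand in for FP8 values).
    a_scales: M x (K // block_k) nested list, indexed [row][k_block]
              (an assumed layout, one scale per row per K-block).
    """
    return [
        [value * a_scales[row][k // block_k] for k, value in enumerate(row_values)]
        for row, row_values in enumerate(a_q)
    ]


a_q = [
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0],
]
a_scales = [
    [0.5, 2.0],
    [10.0, 0.25],
]
result = dequantize_blockwise(a_q, a_scales)
```

In this model, each row of `a_q` shares one scale per group of `BLOCK_K` consecutive K elements, which is why a scales tile is addressed by a row coordinate and a K coordinate.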
Parameters
- tma_origin (ImmutOrigin): Origin of the TMA descriptor pointer.
- dtype (DType): Element data type (typically float8 for scales).
- gmem_layout (Layout): Global memory tensor layout.
- desc_layout (Layout): TMA descriptor layout (tile dimensions).
- cta_group (Int): CTA group size (1 or 2 for SM100 2-SM MMA).
Fields
- tma_op (ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOpPtr): Pointer to the TMA descriptor.
Implemented traits
AnyType, Copyable, ImplicitlyCopyable, ImplicitlyDestructible, Movable
comptime members
__copyinit__is_trivial
comptime __copyinit__is_trivial = True
__del__is_trivial
comptime __del__is_trivial = True
__moveinit__is_trivial
comptime __moveinit__is_trivial = True
TmaOp
comptime TmaOp = TMATensorTile[dtype, gmem_layout, desc_layout]
TmaOpPtr
comptime TmaOpPtr = Pointer[ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]
Methods
__init__
__init__(tma_op: Pointer[ScalesTileLoader[tma_origin, dtype, gmem_layout, desc_layout, cta_group=cta_group].TmaOp, tma_origin]) -> Self
Initialize the scales tile loader.
Args:
- tma_op (Pointer): Pointer to the TMA descriptor (grid constant).
load
load[tile_layout: Layout, /, alignment: Int = 128](self, dest: LayoutTensor[dtype, tile_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, alignment=alignment], ref [3] barrier: SharedMemBarrier, row_coord: Int, k_coord: Int)
Load a scales tile using TMA hardware acceleration.
Issues an async copy from global memory to shared memory. Unlike TileLoaderTMA, this uses (row_coord, k_coord) order matching the scales tensor layout.
Args:
- dest (LayoutTensor): Destination SMEM tile.
- barrier (SharedMemBarrier): Memory barrier for TMA completion signaling.
- row_coord (Int): Row coordinate (M for A-scales) in global memory.
- k_coord (Int): K dimension coordinate in global memory.
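To illustrate why the argument order matters, here is a small synchronous Python model of a tile load addressed as (row_coord, k_coord). It is a sketch under stated assumptions, not the Mojo implementation: the real `load` is asynchronous and signals completion through the barrier, and this page does not state whether the coordinates are element offsets or tile indices. The helper `load_scales_tile` is hypothetical and treats them as element offsets.

```python
def load_scales_tile(scales, tile_rows, tile_k, row_coord, k_coord):
    """Return the tile_rows x tile_k tile whose top-left corner is
    (row_coord, k_coord), with `scales` indexed as [row][k]."""
    return [
        scales[r][k_coord:k_coord + tile_k]
        for r in range(row_coord, row_coord + tile_rows)
    ]


# A 4 x 4 scales table where each entry encodes its own position: 10 * row + k.
scales = [[10 * r + k for k in range(4)] for r in range(4)]

# Intended call: rows 2..3, K columns 0..1.
tile = load_scales_tile(scales, tile_rows=2, tile_k=2, row_coord=2, k_coord=0)

# Same two numbers passed in the opposite order select a different region.
swapped = load_scales_tile(scales, tile_rows=2, tile_k=2, row_coord=0, k_coord=2)
```

Here `tile` holds the entries from rows 2 and 3, while `swapped` holds entries from rows 0 and 1. Code ported from a loader with a different coordinate convention would therefore read the wrong scales without any error if the arguments were passed unchanged.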