Mojo struct

TiledMmaOp

struct TiledMmaOp[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool = False]

TileTensor-native MMA operation for AMD attention kernels.

Wraps the raw GPU MMA intrinsic and operates directly on TileTensor register tiles.

Parameters

  • out_type (DType): Accumulator data type.
  • in_type (DType): Input matrix element data type.
  • shape (IndexList[3]): MMA instruction shape [M, N, K].
  • group_size (Int): Number of MMA operations along the K dimension.
  • transpose_b (Bool): Whether to transpose the B matrix.

Implemented traits

AnyType, ImplicitlyDestructible
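
As a usage sketch, the struct might be parameterized for a 16x16x16 AMD MFMA instruction like this (the parameter values, alias name, and import context below are illustrative assumptions, not taken from this page):

```mojo
# Hypothetical instantiation: bf16 inputs accumulated in f32,
# four MMAs chained along K, B supplied transposed.
alias MmaOp = TiledMmaOp[
    out_type = DType.float32,          # accumulator data type
    in_type = DType.bfloat16,          # A/B element data type
    shape = IndexList[3](16, 16, 16),  # [M, N, K] of one MMA instruction
    group_size = 4,                    # MMA operations along K
    transpose_b = True,
]
```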

comptime members

a_frag_size

comptime a_frag_size = num_matrix_reg[shape[0], shape[2]]()

b_frag_size

comptime b_frag_size = num_matrix_reg[shape[2], shape[1]]()

c_frag_size

comptime c_frag_size = num_matrix_reg[shape[0], shape[1]]()
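
If `num_matrix_reg` follows the usual per-lane fragment-size computation for AMD wavefronts (an assumption; the helper's definition is not shown on this page), each fragment size is the tile's element count divided by the 64-lane wavefront width:

```mojo
# Assumed: num_matrix_reg[m, n]() == m * n // 64 on AMD (64-lane wavefront).
# For shape = [16, 16, 16]:
#   a_frag_size = 16 * 16 // 64 = 4   # per-lane A registers
#   b_frag_size = 16 * 16 // 64 = 4   # per-lane B registers
#   c_frag_size = 16 * 16 // 64 = 4   # per-lane accumulator registers
```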

Methods

mma

static mma[swap_a_b: Bool = False](a: TileTensor[a.dtype, a.LayoutType, a.origin, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b.dtype, b.LayoutType, b.origin, address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[c.dtype, c.LayoutType, c.origin, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size])

Perform grouped MMA on TileTensor operands.

Tiles down to individual MMA fragments so the compiler can prove static shapes, then calls the raw gpu_mma intrinsic directly.

Parameters:

  • swap_a_b (Bool): Whether to swap A and B operands.

Args:

  • a (TileTensor): A operand tile [num_m_mmas, group_size * a_frag_size].
  • b (TileTensor): B operand tile [num_n_mmas, group_size * b_frag_size].
  • c (TileTensor): Accumulator tile [num_m_mmas * num_n_mmas, c_frag_size], modified in-place.
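
A hedged call sketch (the register-tile names below are hypothetical; their construction is outside this page's scope):

```mojo
# a_reg: [num_m_mmas, group_size * a_frag_size] in LOCAL address space
# b_reg: [num_n_mmas, group_size * b_frag_size] in LOCAL address space
# c_reg: [num_m_mmas * num_n_mmas, c_frag_size], accumulated in place
MmaOp.mma(a_reg, b_reg, c_reg)

# Kernels that need the operands exchanged can set swap_a_b:
MmaOp.mma[swap_a_b=True](a_reg, b_reg, c_reg)
```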

load_b

static load_b[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, warp_tile.LayoutType, warp_tile.origin, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, reg_tile.LayoutType, reg_tile.origin, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: UInt = 0)

Load B-matrix fragments from shared memory (SMEM) to registers.

Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Handles both transposed and non-transposed B layouts via comptime dispatch.

Parameters:

  • swizzle (Optional[Swizzle]): Optional swizzle pattern for avoiding SMEM bank conflicts.

Args:

  • warp_tile (TileTensor): Source warp tile in shared memory.
  • reg_tile (TileTensor): Destination register tile for MMA fragments.
  • k_group_idx (UInt): K-dimension group index within the warp tile.

load_a

static load_a[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, warp_tile.LayoutType, warp_tile.origin, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, reg_tile.LayoutType, reg_tile.origin, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: UInt = 0)

Load A-matrix fragments from shared memory (SMEM) to registers.

Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Always uses col_major thread distribution.

Parameters:

  • swizzle (Optional[Swizzle]): Optional swizzle pattern for SMEM access.

Args:

  • warp_tile (TileTensor): Source warp tile in shared memory.
  • reg_tile (TileTensor): Destination register tile for MMA fragments.
  • k_group_idx (UInt): K-dimension group index within the warp tile.
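
Putting the three static methods together, an attention kernel's inner loop might stage operands from SMEM and accumulate as follows (a sketch only: the warp-tile and register-tile names, swizzle values, and k-group count are assumptions):

```mojo
# Hypothetical inner loop: stage A and B fragments for each K group,
# then accumulate into the in-register C tile.
for k_group in range(num_k_groups):
    MmaOp.load_a[swizzle=swizzle_a](a_smem_warp_tile, a_reg, k_group)
    MmaOp.load_b[swizzle=swizzle_b](b_smem_warp_tile, b_reg, k_group)
    MmaOp.mma(a_reg, b_reg, c_reg)  # c_reg is modified in place
```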
