Mojo struct

TiledMmaOp

struct TiledMmaOp[out_type: DType, in_type: DType, shape: IndexList[3], transpose_b: Bool = False]

TileTensor-native MMA operation for AMD attention kernels.

Wraps the raw GPU MMA intrinsic and operates directly on TileTensor register tiles.

Parameters

  • out_type (DType): Accumulator data type.
  • in_type (DType): Input matrix element data type.
  • shape (IndexList[3]): MMA instruction shape [M, N, K].
  • transpose_b (Bool): Whether to transpose the B matrix.
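
For illustration, a hedged sketch of one possible parameterization: bf16 inputs with an f32 accumulator and a 16x16x16 MMA shape (a common AMD MFMA tile size). The dtypes, shape, and transpose_b value below are assumptions chosen for this example, not requirements of the struct, and the import of TiledMmaOp itself is omitted:

from utils.index import IndexList

# Hypothetical instantiation; all concrete values here are illustrative.
comptime MmaOp = TiledMmaOp[
    out_type=DType.float32,
    in_type=DType.bfloat16,
    shape=IndexList[3](16, 16, 16),
    transpose_b=True,
]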

Implemented traits

AnyType, ImplicitlyDestructible

comptime members

a_frag_size

comptime a_frag_size = num_matrix_reg[shape[0], shape[2]]()

b_frag_size

comptime b_frag_size = num_matrix_reg[shape[2], shape[1]]()

c_frag_size

comptime c_frag_size = num_matrix_reg[shape[0], shape[1]]()
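
As a rough sanity check, assuming num_matrix_reg[rows, cols]() yields the per-thread fragment size rows * cols // wavefront_width (an assumption about that helper, not something stated on this page), the 16x16x16 example shape on a 64-lane AMD wavefront works out to four elements per fragment per thread:

comptime WAVE_SIZE = 64  # assumed AMD wavefront width
comptime M = 16
comptime N = 16
comptime K = 16
comptime a_frag_size = M * K // WAVE_SIZE  # 16 * 16 // 64 = 4
comptime b_frag_size = K * N // WAVE_SIZE  # 4
comptime c_frag_size = M * N // WAVE_SIZE  # 4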

Methods

mma

static mma[swap_a_b: Bool = False](a: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size])

Perform MMA on TileTensor operands.

Tiles down to individual MMA fragments so the compiler can prove static shapes, then calls the raw gpu_mma intrinsic directly.

Parameters:

  • swap_a_b (Bool): Whether to swap A and B operands. Only controls the argument order of gpu_mma; accumulator indexing is always col-major over (M, N).

Args:

  • a (TileTensor): A-matrix register tile in local memory, e.g. filled by load_a.
  • b (TileTensor): B-matrix register tile in local memory, e.g. filled by load_b.
  • c (TileTensor): Accumulator register tile in local memory.
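
A hedged sketch of a call site, reusing the MmaOp alias from the example above; a_reg, b_reg, and c_reg stand for LOCAL-address-space TileTensor register tiles whose construction is omitted, and those names are placeholders rather than part of this API:

# a_reg, b_reg: register tiles previously filled by MmaOp.load_a / MmaOp.load_b.
# c_reg: accumulator register tile.
MmaOp.mma(a_reg, b_reg, c_reg)

# Same operands, handed to the underlying gpu_mma intrinsic in swapped order;
# accumulator indexing over (M, N) is unaffected.
MmaOp.mma[swap_a_b=True](a_reg, b_reg, c_reg)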

load_b

static load_b[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = 0)

Load B-matrix fragments from SMEM to registers.

Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Handles both transposed and non-transposed B layouts via comptime dispatch.

Parameters:

  • swizzle (Optional[Swizzle]): Optional swizzle for SMEM bank-conflict avoidance.

Args:

  • warp_tile (TileTensor): Source B warp tile in shared memory.
  • reg_tile (TileTensor): Destination register tile for the loaded fragments.
  • k_group_idx (Int): Index of the K group to load from the warp tile.
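
For example, loading the first K group of B fragments with a swizzle supplied for bank-conflict avoidance; bank_swizzle, b_warp_tile, and b_reg are placeholder names introduced for this sketch, not part of the API:

# b_warp_tile: B tile staged in shared memory; b_reg: destination register tile.
# bank_swizzle: a Swizzle chosen to avoid SMEM bank conflicts (placeholder).
MmaOp.load_b[swizzle=bank_swizzle](b_warp_tile, b_reg, k_group_idx=0)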

load_a

static load_a[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = 0)

Load A-matrix fragments from SMEM to registers.

Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Always uses col_major thread distribution.

Parameters:

  • swizzle (Optional[Swizzle]): Optional swizzle for SMEM bank-conflict avoidance.

Args:

  • warp_tile (TileTensor): Source A warp tile in shared memory.
  • reg_tile (TileTensor): Destination register tile for the loaded fragments.
  • k_group_idx (Int): Index of the K group to load from the warp tile.
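
Putting the three methods together, a hedged sketch of a per-K-group inner loop as it might appear in a warp-level kernel; a_warp_tile and b_warp_tile stand for SHARED-memory warp tiles, the *_reg tiles live in registers, and num_k_groups is an illustrative loop bound, none of which are defined on this page:

# Accumulate one pass over the K groups of the current warp tiles into c_reg.
for k_group in range(num_k_groups):
    MmaOp.load_a(a_warp_tile, a_reg, k_group)  # A fragments (col_major distribution)
    MmaOp.load_b(b_warp_tile, b_reg, k_group)  # B fragments (honors transpose_b)
    MmaOp.mma(a_reg, b_reg, c_reg)             # accumulate into c_reg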