Mojo struct

TiledMmaOp

struct TiledMmaOp[out_type: DType, in_type: DType, shape: IndexList[Int(3)], transpose_b: Bool = False]

TileTensor-native matrix multiply-accumulate (MMA) operation for AMD attention kernels.

Wraps the raw GPU MMA intrinsic and operates directly on TileTensor register tiles.

Parameters

  • out_type (DType): Accumulator data type.
  • in_type (DType): Input matrix element data type.
  • shape (IndexList[Int(3)]): MMA instruction shape [M, N, K].
  • transpose_b (Bool): Whether to transpose the B matrix.

Implemented traits

AnyType, ImplicitlyDeletable

comptime members

a_frag_size

comptime a_frag_size = num_matrix_reg[shape[Int(0)], shape[Int(2)]]()

b_frag_size

comptime b_frag_size = num_matrix_reg[shape[Int(2)], shape[Int(1)]]()

c_frag_size

comptime c_frag_size = num_matrix_reg[shape[Int(0)], shape[Int(1)]]()
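
As a rough illustration of what the three fragment sizes represent, the sketch below assumes that `num_matrix_reg[rows, cols]()` evaluates to `rows * cols // WARP_SIZE`, that is, the number of elements each lane holds when a `rows x cols` matrix is spread evenly over one warp. It also assumes a 64-lane warp and an example shape of [32, 32, 8]. None of these are stated on this page, and `frag_size` is a hypothetical stand-in rather than the library function.

```mojo
# Hypothetical stand-in for num_matrix_reg: assumes the matrix elements
# are split evenly across the lanes of one warp.
def frag_size(rows: Int, cols: Int, warp_size: Int) -> Int:
    return (rows * cols) // warp_size


def main():
    # Assumed example values, not taken from this page.
    var m = 32
    var n = 32
    var k = 8
    var warp_size = 64

    print("a_frag_size:", frag_size(m, k, warp_size))  # A is M x K
    print("b_frag_size:", frag_size(k, n, warp_size))  # B is K x N
    print("c_frag_size:", frag_size(m, n, warp_size))  # C is M x N
```

Under these assumptions each lane would hold 4 elements of A, 4 of B, and 16 of the accumulator.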

Methods

mma

static def mma[swap_a_b: Bool = False](a: TileTensor[Storage=a.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[Storage=b.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[Storage=c.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=c.linear_idx_type, element_size=c.element_size])

Perform MMA on TileTensor operands.

Tiles down to individual MMA fragments so the compiler can prove static shapes, then calls the raw gpu_mma intrinsic directly.

Parameters:

  • swap_a_b (Bool): Whether to swap the A and B operands. This only controls the argument order passed to gpu_mma; accumulator indexing is always column-major over (M, N).
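
The column-major accumulator indexing can be sketched as follows, assuming the register tiles decompose into a grid of `num_m` by `num_n` MMA fragments. The helper `c_frag_index` and the fragment counts are hypothetical names and values chosen for illustration; the actual kernel may organize its loops differently.

```mojo
# Column-major linear index of the accumulator fragment at (m, n):
# the M index varies fastest, the N index slowest.
def c_frag_index(m: Int, n: Int, num_m: Int) -> Int:
    return m + n * num_m


def main():
    # Hypothetical fragment counts along M and N.
    var num_m = 2
    var num_n = 3

    for n in range(num_n):
        for m in range(num_m):
            print("m =", m, "n =", n, "-> c index", c_frag_index(m, n, num_m))
```

Per the description of `swap_a_b`, swapping the operands would leave these accumulator indices unchanged; only the order in which the A and B fragments are handed to the intrinsic differs.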

Args:

load_b

static def load_b[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, Storage=warp_tile.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, Storage=reg_tile.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = Int(0))

Load B-matrix fragments from SMEM to registers.

Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Handles both transposed and non-transposed B layouts via comptime dispatch.

Parameters:

  • swizzle (Optional[Swizzle]): Optional swizzle for SMEM bank-conflict avoidance.
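
To give a sense of how a swizzle can reduce bank conflicts, here is a minimal sketch of a generic XOR swizzle. It is not necessarily the pattern the `Swizzle` type applies in this kernel; the function `xor_swizzle` and its parameter values are assumptions made for illustration only.

```mojo
# Generic XOR swizzle: take `bits` bits starting at bit position
# (base + shift) and XOR them into the bits starting at `base`.
def xor_swizzle(offset: Int, bits: Int, base: Int, shift: Int) -> Int:
    var mask = ((1 << bits) - 1) << (base + shift)
    return offset ^ ((offset & mask) >> shift)


def main():
    # Hypothetical layout: rows of 8 elements, swizzled with
    # bits = 2, base = 0, shift = 3.
    var row_stride = 8
    for row in range(4):
        var offset = row * row_stride  # column 0 of each row
        print("row", row, "offset", offset, "->", xor_swizzle(offset, 2, 0, 3))
```

With these values the first element of rows 0 to 3 maps to offsets 0, 9, 18, and 27, so accesses down a column no longer share the same low-order offset bits.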

Args:

load_a

static def load_a[swizzle: Optional[Swizzle] = None](warp_tile: TileTensor[in_type, Storage=warp_tile.Storage, address_space=AddressSpace.SHARED, linear_idx_type=warp_tile.linear_idx_type, element_size=warp_tile.element_size], reg_tile: TileTensor[in_type, Storage=reg_tile.Storage, address_space=AddressSpace.LOCAL, linear_idx_type=reg_tile.linear_idx_type, element_size=reg_tile.element_size], k_group_idx: Int = Int(0))

Load A-matrix fragments from SMEM to registers.

Distributes the warp tile across threads with optional swizzle, loading one MMA fragment per iteration. Always uses col_major thread distribution.

Parameters:

  • swizzle (Optional[Swizzle]): Optional swizzle for SMEM bank-conflict avoidance.

Args: