
Mojo struct

TiledTensorCore

struct TiledTensorCore[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool = False]

TiledTensorCore provides a wrapper around TensorCore to support multiple MMAs along the K dimension.

Enables larger K-dimension operations by decomposing them into multiple smaller MMA operations. Currently used only on AMD GPUs, where it enables 16x16x32 operations by issuing two 16x16x16 MMAs.

Parameters

  • out_type (DType): The data type for output/accumulation operations.
  • in_type (DType): The data type for input matrix elements.
  • shape (IndexList): The shape of each individual MMA operation, given as [M, N, K].
  • group_size (Int): Number of MMA operations along the K dimension.
  • transpose_b (Bool): Whether to transpose the b matrix. Defaults to False.
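
For example, a minimal sketch of how this struct might be parameterized for the AMD case described above. The element types, the alias name, and the import paths are assumptions for illustration, not part of this documentation:

    from layout.tensor_core import TiledTensorCore  # import path assumed
    from utils.index import IndexList

    # Hypothetical parameterization: one 16x16x32 operation composed of
    # two 16x16x16 MMAs along K (group_size = 2). Element types are
    # chosen for illustration only.
    alias Mma16x16x32 = TiledTensorCore[
        DType.float32,             # out_type: accumulate in float32
        DType.bfloat16,            # in_type: bfloat16 operands
        IndexList[3](16, 16, 16),  # shape of each individual MMA [M, N, K]
        2,                         # group_size: two MMAs along K -> effective K = 32
    ]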

Implemented traits

AnyType, UnknownDestructibility

Aliases

__del__is_trivial

alias __del__is_trivial = True

mma_op

alias mma_op = TensorCore[out_type, in_type, shape, transpose_b]()

Methods

mma

static mma[swap_a_b: Bool = False](a_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])

Perform multiple matrix multiply-accumulate operations along the K dimension.

Executes group_size MMA operations, processing slices of the K dimension and accumulating results in c_reg_tile.

Parameters:

  • swap_a_b (Bool): Whether to swap a and b operands. Defaults to False.

Args:

  • a_reg_tile (LayoutTensor): Input matrix a fragments [num_m_mmas, group_size * a_frag_size].
  • b_reg_tile (LayoutTensor): Input matrix b fragments [num_n_mmas, group_size * b_frag_size].
  • c_reg_tile (LayoutTensor): Accumulation matrix c fragments, modified in-place.
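
A minimal, hypothetical call sketch follows. It assumes the fragment tiles a_frags, b_frags, and c_frags are register LayoutTensors that were already loaded and zero-initialized elsewhere in the kernel (for instance with TensorCore's load helpers), with group_size fragments packed along K:

    # Hypothetical kernel fragment, reusing the Mma16x16x32 alias from the
    # earlier sketch; the tile names and their preparation are assumptions.
    # a_frags: [num_m_mmas, group_size * a_frag_size] register fragments of A
    # b_frags: [num_n_mmas, group_size * b_frag_size] register fragments of B
    # c_frags: accumulator fragments of C, updated in place by the call
    Mma16x16x32.mma(a_frags, b_frags, c_frags)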
