Mojo struct
TiledTensorCore
struct TiledTensorCore[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool = False]
TiledTensorCore provides a wrapper around TensorCore to support multiple MMA (matrix multiply-accumulate) operations along the K dimension.
It enables operations with a larger K dimension by decomposing each one into group_size smaller MMA operations along K. It is currently used only on AMD GPUs, where it performs a 16x16x32 operation as two 16x16x16 MMAs.
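The decomposition is ordinary block matrix multiplication. The following worked equation uses shape = [M, N, K] for the per-MMA shape and G for group_size, so the tiled operation reduces over G·K. A is M × GK and B is GK × N, and B is shown in its logical orientation whatever transpose_b is set to. A_g and B_g denote the g-th K-wide column block of A and row block of B:

```latex
(AB)_{ij} = \sum_{t=0}^{GK-1} A_{it} B_{tj}
          = \sum_{g=0}^{G-1} \sum_{k=0}^{K-1} A_{i,\,gK+k}\, B_{gK+k,\,j}
          = \sum_{g=0}^{G-1} (A_g B_g)_{ij},
\qquad
(A_g)_{ik} = A_{i,\,gK+k}, \quad (B_g)_{kj} = B_{gK+k,\,j}, \quad 0 \le k < K
```

This gives C ← C + A_0 B_0 + … + A_{G−1} B_{G−1}, where each term is one MMA of shape [M, N, K]. With M = N = K = 16 and G = 2, this is the AMD case described above: a 16x16x32 product computed as two accumulated 16x16x16 MMAs. The identity is exact in real arithmetic.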
Parameters
- out_type (DType): The data type for output/accumulation operations.
- in_type (DType): The data type for input matrix elements.
- shape (IndexList): The shape [M, N, K] of each individual MMA operation.
- group_size (Int): The number of MMA operations along the K dimension.
- transpose_b (Bool): Whether to transpose the B matrix. Defaults to False.
Implemented traits
AnyType, UnknownDestructibility
Aliases
__del__is_trivial
alias __del__is_trivial = True
mma_op
alias mma_op = TensorCore[out_type, in_type, shape, transpose_b]()
Methods
mma
static mma[swap_a_b: Bool = False](a_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], b_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], c_reg_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
Perform multiple matrix multiply-accumulate operations along the K dimension.
Executes group_size MMA operations, processing slices of the K dimension and accumulating results in c_reg_tile.
Parameters:
- swap_a_b (Bool): Whether to swap the a and b operands. Defaults to False.
Args:
- a_reg_tile (LayoutTensor): Fragments of input matrix A, with shape [num_m_mmas, group_size * a_frag_size].
- b_reg_tile (LayoutTensor): Fragments of input matrix B, with shape [num_n_mmas, group_size * b_frag_size].
- c_reg_tile (LayoutTensor): Accumulator fragments of matrix C, modified in place.
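This page does not define a_frag_size or b_frag_size. The following sketch shows how the register-tile widths could work out under stated assumptions. It assumes that each MMA's A and B operands are spread evenly across the W lanes of a wavefront. The value W = 64 (a 64-lane wavefront, as on AMD CDNA GPUs) is also an assumption and is not stated on this page:

```latex
\text{a\_frag\_size} = \frac{M K}{W}, \qquad
\text{b\_frag\_size} = \frac{K N}{W}, \qquad
\text{row width of a\_reg\_tile} = G \cdot \text{a\_frag\_size}
```

For M = N = K = 16, W = 64, and G = 2, this gives a_frag_size = 256 / 64 = 4. Each row of a_reg_tile then holds 2 · 4 = 8 elements, and the same numbers apply to b_reg_tile. Under the same assumption, one natural reading of the layout is that the g-th MMA consumes elements g · a_frag_size through (g + 1) · a_frag_size − 1 of each row. The exact slicing is not spelled out here.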