Mojo function

get_fragment_size

get_fragment_size[mma_shape: Index[3]]() -> Index[3]

Calculates the fragment size per thread for a given MMA shape.

For tensor core operations, the threads of a warp cooperate on a single matrix multiply-accumulate (MMA), and each thread handles only a portion of every matrix. This function determines how many elements of the A, B, and C/D matrices each thread needs to process, based on the MMA shape.

Parameters:

  • mma_shape (Index[3]): An IndexList[3] containing the MMA dimensions [M, N, K].

Returns:

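An IndexList[3] containing the fragment sizes per thread for matrices A, B, and C/D respectively. Each size is the element count of the matrix (M*K for A, N*K for B, M*N for C/D) divided evenly across the threads of a warp: [M*K/WARP_SIZE, N*K/WARP_SIZE, M*N/WARP_SIZE].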
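
To make the formula concrete, here is a minimal sketch in plain Python that models the arithmetic. It is not the Mojo implementation. The helper name `fragment_size` is hypothetical, `WARP_SIZE = 32` is an assumption (the usual NVIDIA warp width; other targets may differ), and the divisions are assumed to be integer divisions.

```python
# Plain-Python model of the documented formula, not the Mojo source.
WARP_SIZE = 32  # assumption: 32 threads per warp


def fragment_size(mma_shape):
    """Return per-thread fragment sizes (A, B, C/D) for an [M, N, K] MMA shape."""
    m, n, k = mma_shape
    return (
        m * k // WARP_SIZE,  # A is M x K
        n * k // WARP_SIZE,  # B is N x K elements in total
        m * n // WARP_SIZE,  # C/D is M x N
    )


print(fragment_size((16, 8, 8)))   # (4, 2, 4)
print(fragment_size((16, 8, 16)))  # (8, 4, 4)
```

Under these assumptions, a 16x8x8 MMA gives each thread 4 elements of A, 2 of B, and 4 of C/D. Doubling K to 16 doubles the A and B fragments and leaves the C/D fragment unchanged.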