Mojo function
get_fragment_size
get_fragment_size[mma_shape: IndexList[3]]() -> IndexList[3]
Calculates the fragment size per thread for a given MMA shape.
For tensor core operations, each thread in a warp handles a portion of the computation. This function determines how many elements each thread needs to process for the A, B, and C/D matrices based on the MMA shape.
Parameters:
- mma_shape (IndexList): An IndexList[3] containing the MMA dimensions [M, N, K].
Returns:
IndexList: An IndexList[3] containing the fragment sizes per thread for matrices A, B, and C/D respectively, calculated as [M*K/WARP_SIZE, N*K/WARP_SIZE, M*N/WARP_SIZE].
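The formula above can be checked by hand with a small sketch. This is not the Mojo implementation, only a Python re-statement of the documented arithmetic. The helper name `fragment_size` and the default warp size of 32 are assumptions made for illustration; the real WARP_SIZE depends on the target GPU.

```python
def fragment_size(mma_shape, warp_size=32):
    """Per-thread fragment sizes for A, B, and C/D (illustrative only).

    mma_shape is (M, N, K). warp_size=32 is an assumed value; the Mojo
    function uses the WARP_SIZE of the target GPU instead.
    """
    m, n, k = mma_shape
    # A is M x K, B is K x N (N*K elements), C/D is M x N.
    # Each matrix's elements are split evenly across the threads of one warp.
    return (m * k // warp_size, n * k // warp_size, m * n // warp_size)


# An MMA shape of [16, 8, 8] with 32 threads per warp:
# A: 16*8/32 = 4, B: 8*8/32 = 2, C/D: 16*8/32 = 4
print(fragment_size((16, 8, 8)))
```

Under the same assumption, a shape of [16, 8, 16] gives 8 elements of A, 4 of B, and 4 of C/D per thread, since only the K dimension (shared by A and B) has doubled.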