Mojo function
get_fragment_size
get_fragment_size[mma_shape: IndexList[3]]() -> IndexList[3]
Calculates the fragment size per thread for a given MMA shape.
For tensor core operations, each thread in a warp handles a portion of the computation. This function determines how many elements each thread needs to process for the A, B, and C/D matrices based on the MMA shape.
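To make the per-thread split concrete, the sketch below is a hypothetical Python illustration, not Mojo and not the library's code. It hands every element of an M x K tile to one of WARP_SIZE lanes in round-robin order and counts how many elements each lane owns. Real tensor-core fragment layouts are hardware-specific and are not round-robin. The only point here is that, when the tile divides evenly, every lane ends up with the same count, M*K/WARP_SIZE. WARP_SIZE = 32 is an assumption that matches NVIDIA warps; AMD wavefronts are 64 lanes wide.

```python
from collections import Counter

WARP_SIZE = 32  # assumption: NVIDIA warp size (AMD wavefronts use 64)


def elements_per_lane(rows: int, cols: int, warp_size: int = WARP_SIZE) -> Counter:
    """Count how many tile elements each lane owns under a round-robin split.

    Hypothetical helper for illustration only; real MMA fragment layouts differ,
    but the per-lane element count is the same when the tile divides evenly.
    """
    counts = Counter()
    for idx in range(rows * cols):
        counts[idx % warp_size] += 1  # lane that owns this element
    return counts


# The A tile of a 16x8x16 MMA is 16 x 16 = 256 elements across 32 lanes.
counts = elements_per_lane(16, 16)
print(len(counts), set(counts.values()))  # 32 {8}
```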
Parameters:
- mma_shape (IndexList[3]): An IndexList[3] containing the MMA dimensions [M, N, K].
Returns:
An IndexList[3] containing the fragment sizes per thread for matrices A, B, and C/D, respectively, calculated as [M*K/WARP_SIZE, N*K/WARP_SIZE, M*N/WARP_SIZE].
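The return formula can be checked by hand. The sketch below is a hypothetical Python mirror of the documented arithmetic, not the Mojo implementation. The `fragment_size` name, the default `warp_size` of 32 (NVIDIA), and the divisibility check are assumptions added for illustration; the check is not documented behavior of `get_fragment_size`.

```python
WARP_SIZE = 32  # assumption: NVIDIA warp size; pass 64 for AMD wavefronts


def fragment_size(mma_shape, warp_size=WARP_SIZE):
    """Per-thread fragment sizes [A, B, C/D] for an MMA shape [M, N, K].

    Hypothetical mirror of the documented formula:
    [M*K/WARP_SIZE, N*K/WARP_SIZE, M*N/WARP_SIZE].
    """
    m, n, k = mma_shape
    totals = (m * k, n * k, m * n)  # element counts of the A, B, and C/D tiles
    # Divisibility check added by this sketch; not documented library behavior.
    for total in totals:
        if total % warp_size:
            raise ValueError(f"{total} elements do not split evenly over {warp_size} lanes")
    return [total // warp_size for total in totals]


print(fragment_size((16, 8, 16)))                 # [8, 4, 4]
print(fragment_size((16, 8, 8)))                  # [4, 2, 4]
print(fragment_size((16, 16, 16), warp_size=64))  # [4, 4, 4]
```

For the 16x8x16 shape, each of the 32 threads holds 8 elements of A, 4 of B, and 4 of the C/D accumulator.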