Skip to main content

Mojo function

warp_specialized_matmul

```mojo
warp_specialized_matmul[M: Int, N: Int, K: Int, BM: Int, BN: Int, BK: Int, WM: Int, WN: Int, WK: Int, a_producer_warps: Int, b_producer_warps: Int, consumer_warps: Int, pipeline_stages: Int = 1](a_tt: TileTensor[DType.bfloat16, a_tt.LayoutType, a_tt.origin, address_space=a_tt.address_space, linear_idx_type=a_tt.linear_idx_type, element_size=a_tt.element_size], b_tt: TileTensor[DType.bfloat16, b_tt.LayoutType, b_tt.origin, address_space=b_tt.address_space, linear_idx_type=b_tt.linear_idx_type, element_size=b_tt.element_size], c_tt: TileTensor[DType.float32, c_tt.LayoutType, c_tt.origin, address_space=c_tt.address_space, linear_idx_type=c_tt.linear_idx_type, element_size=c_tt.element_size], ctx: DeviceContext)
```

Was this page helpful?