Mojo function

warp_specialized_matmul

warp_specialized_matmul[
    M: Int, N: Int, K: Int,
    BM: Int, BN: Int, BK: Int,
    WM: Int, WN: Int, WK: Int,
    a_producer_warps: Int, b_producer_warps: Int, consumer_warps: Int,
    pipeline_stages: Int = 1
](
    a_device_tensor: LayoutTensor[DType.bfloat16, Layout.row_major(VariadicList(M, K)), a_device_tensor.origin, address_space=a_device_tensor.address_space, element_layout=a_device_tensor.element_layout, layout_int_type=a_device_tensor.layout_int_type, linear_idx_type=a_device_tensor.linear_idx_type, masked=a_device_tensor.masked, alignment=a_device_tensor.alignment],
    b_device_tensor: LayoutTensor[DType.bfloat16, Layout.row_major(VariadicList(N, K)), b_device_tensor.origin, address_space=b_device_tensor.address_space, element_layout=b_device_tensor.element_layout, layout_int_type=b_device_tensor.layout_int_type, linear_idx_type=b_device_tensor.linear_idx_type, masked=b_device_tensor.masked, alignment=b_device_tensor.alignment],
    c_device_tensor: LayoutTensor[DType.float32, Layout.row_major(VariadicList(M, N)), c_device_tensor.origin, address_space=c_device_tensor.address_space, element_layout=c_device_tensor.element_layout, layout_int_type=c_device_tensor.layout_int_type, linear_idx_type=c_device_tensor.linear_idx_type, masked=c_device_tensor.masked, alignment=c_device_tensor.alignment],
    ctx: DeviceContext
)
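
The signature gives the operand shapes: `a_device_tensor` is row-major `M x K`, `b_device_tensor` is row-major `N x K`, and `c_device_tensor` is row-major `M x N` in `float32`. Those shapes suggest, though the page does not state it, that the kernel computes `C = A * B^T`. The plain-Python sketch below illustrates that assumed result only. The helper name `reference_matmul_bt` is hypothetical, and the sketch does not model `bfloat16` rounding, tiling (`BM`, `BN`, `BK`, `WM`, `WN`, `WK`), warp roles, or pipelining.

```python
def reference_matmul_bt(a, b):
    """Hypothetical reference: C[m][n] = sum over k of A[m][k] * B[n][k].

    Assumes A is M x K and B is N x K (both row-major lists of rows),
    matching the layouts in the signature. The C = A * B^T reading is an
    inference from those shapes, not something the source page states.
    """
    M, K = len(a), len(a[0])
    N = len(b)
    # Both operands must share the reduction dimension K.
    assert all(len(row) == K for row in a)
    assert all(len(row) == K for row in b)

    c = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0  # Python float accumulator (the kernel output is float32)
            for k in range(K):
                acc += a[m][k] * b[n][k]
            c[m][n] = acc
    return c


if __name__ == "__main__":
    a = [[1.0, 2.0], [3.0, 4.0]]               # M = 2, K = 2
    b = [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]  # N = 3, K = 2
    for row in reference_matmul_bt(a, b):
        print(row)
```

A reference like this is mainly useful for checking a GPU result on small inputs, with a tolerance that accounts for `bfloat16` inputs.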