
function

matmul_broadcast

matmul_broadcast(lhs: Symbol, rhs: Symbol) -> List[Symbol]

Broadcasts two symbolic tensors to shapes that are compatible for a matrix multiplication (matmul).

Args:

  • lhs (Symbol): The left-hand operand of the matmul.
  • rhs (Symbol): The right-hand operand of the matmul.

Returns:

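A list of two symbolic tensors, corresponding to lhs and rhs respectively, each broadcast to the shape needed to perform a matmul between them. Only the batch dimensions (all but the final two) are broadcast; the final two dimensions of each tensor, which hold the matrices being multiplied, are not broadcast.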
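To make the shape rule concrete, here is a minimal pure-Python sketch that works on plain integer shape tuples rather than on `Symbol` values. It assumes that the batch dimensions follow NumPy-style broadcasting (shapes right-aligned, size-1 or missing dimensions stretched). It also assumes both operands have rank of at least 2. The helpers `broadcast_batch_dims` and `matmul_broadcast_shapes` are hypothetical illustrations, not part of this API.

```python
from itertools import zip_longest
from typing import List, Sequence, Tuple

Shape = Tuple[int, ...]


def broadcast_batch_dims(a: Sequence[int], b: Sequence[int]) -> Shape:
    """Broadcast two batch shapes with NumPy-style rules (assumption)."""
    out = []
    # Walk both shapes from the last dimension backwards; a missing
    # leading dimension is treated as size 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"incompatible batch dimensions {x} and {y}")
    return tuple(reversed(out))


def matmul_broadcast_shapes(lhs: Sequence[int], rhs: Sequence[int]) -> List[Shape]:
    """Hypothetical helper: broadcast shapes of lhs and rhs for a matmul."""
    if len(lhs) < 2 or len(rhs) < 2:
        raise ValueError("this sketch assumes both operands have rank >= 2")
    # Broadcast only the batch dimensions (everything except the last two).
    batch = broadcast_batch_dims(lhs[:-2], rhs[:-2])
    # Keep each operand's final two (matrix) dimensions unchanged.
    return [batch + tuple(lhs[-2:]), batch + tuple(rhs[-2:])]


print(matmul_broadcast_shapes((2, 1, 3, 4), (5, 4, 6)))
# [(2, 5, 3, 4), (2, 5, 4, 6)]
```

In the usage line, the batch shapes `(2, 1)` and `(5,)` broadcast to `(2, 5)`, while the matrix dimensions `(3, 4)` and `(4, 6)` pass through untouched. The sketch deliberately does not check that the inner dimensions agree (`4` here), since that belongs to the matmul itself rather than to broadcasting.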