function

matmul_by_matrix

matmul_by_matrix(lhs: Symbol, rhs: Symbol) -> Symbol

Computes the matrix multiplication of two symbolic tensors.

The last two dimensions of lhs are treated as matrices and multiplied by rhs, which must be a 2D tensor. Any remaining leading dimensions of lhs are broadcast dimensions.
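The behavior described above can be illustrated with a small reference sketch in plain Python over nested lists. This is not the library's implementation or API: `matmul_2d` and `matmul_by_matrix_ref` are hypothetical names introduced only for illustration, and the sketch assumes that the same rhs matrix is applied to every matrix found in the last two dimensions of lhs.

```python
def matmul_2d(a, b):
    """Multiply an (m x k) nested list by a (k x n) nested list."""
    k, n = len(b), len(b[0])
    return [
        [sum(row[p] * b[p][j] for p in range(k)) for j in range(n)]
        for row in a
    ]


def matmul_by_matrix_ref(lhs, rhs, lhs_rank):
    """Hypothetical reference: apply rhs to every trailing 2D matrix in lhs.

    lhs is a nested list of rank lhs_rank (assumed >= 2); rhs is a 2D nested list.
    """
    if lhs_rank == 2:
        return matmul_2d(lhs, rhs)
    # Leading dimensions are treated as broadcast (batch) dimensions.
    return [matmul_by_matrix_ref(sub, rhs, lhs_rank - 1) for sub in lhs]


# lhs has shape (2, 2, 2): a batch of two 2x2 matrices.
lhs = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
# rhs has shape (2, 1): multiplying by it sums each row of every matrix.
rhs = [[1], [1]]
result = matmul_by_matrix_ref(lhs, rhs, lhs_rank=3)
```

Here `result` has shape (2, 2, 1): the batch dimension of lhs is preserved, and each 2x2 matrix is multiplied by the same 2x1 rhs.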

Args:

  • lhs (Symbol): The left-hand side of the matmul.
  • rhs (Symbol): The right-hand side of the matmul. Must be rank 2 (a 2D tensor, i.e. a matrix).

Returns:

A symbolic tensor representing the result of broadcasting the two matrices together according to matmul_broadcast and then performing a matrix multiply along the last two dimensions of each tensor.
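As a rough guide to the shape of the returned tensor, the following Python sketch computes the expected result shape from the operand shapes. The helper name `matmul_by_matrix_shape` is hypothetical, and the rule it encodes (lhs of rank at least 2, inner dimensions equal, leading lhs dimensions carried through unchanged) is an assumption drawn from the description above rather than a statement of what matmul_broadcast does in every case.

```python
def matmul_by_matrix_shape(lhs_shape, rhs_shape):
    """Hypothetical helper: expected result shape for lhs @ rhs, rhs rank 2."""
    if len(rhs_shape) != 2:
        raise ValueError("rhs must be rank 2")
    if len(lhs_shape) < 2:
        # Assumption: lhs needs at least two dimensions to form a matrix.
        raise ValueError("lhs must have rank >= 2")
    *batch, m, k = lhs_shape
    k_rhs, n = rhs_shape
    if k != k_rhs:
        raise ValueError(f"inner dimensions differ: {k} vs {k_rhs}")
    # Leading (broadcast) dimensions of lhs are preserved in the result.
    return (*batch, m, n)


shape = matmul_by_matrix_shape((8, 4, 16, 32), (32, 64))
```

Under these assumptions, an lhs of shape (8, 4, 16, 32) multiplied by an rhs of shape (32, 64) gives a result of shape (8, 4, 16, 64).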