
function

batch_matmul

batch_matmul(lhs: Symbol, rhs: Symbol) -> Symbol

Computes the matrix multiplication of two symbolic tensors.

The last two dimensions of each tensor are treated as matrices and multiplied; the remaining leading dimensions are treated as batch dimensions and broadcast against each other.
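The shape rule described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation. It assumes that `matmul_broadcast` right-aligns the batch dimensions and broadcasts them NumPy-style, where size-1 dimensions stretch and any other mismatch is an error. The helper name `batch_matmul_shape` is hypothetical.

```python
def batch_matmul_shape(lhs_shape, rhs_shape):
    """Hypothetical helper: output shape of a batched matmul.

    Assumes NumPy-style broadcasting of the leading (batch) dimensions.
    """
    if len(lhs_shape) < 2 or len(rhs_shape) < 2:
        raise ValueError("both operands must have rank >= 2")

    # Split each shape into batch dimensions and the trailing matrix dimensions.
    *lhs_batch, m, k = lhs_shape
    *rhs_batch, k_rhs, n = rhs_shape
    if k != k_rhs:
        raise ValueError(f"inner dimensions differ: {k} vs {k_rhs}")

    # Right-align the batch dimensions by padding the shorter one with 1s.
    rank = max(len(lhs_batch), len(rhs_batch))
    lhs_batch = [1] * (rank - len(lhs_batch)) + lhs_batch
    rhs_batch = [1] * (rank - len(rhs_batch)) + rhs_batch

    # Broadcast each batch dimension pair: equal sizes match, size 1 stretches.
    out_batch = []
    for a, b in zip(lhs_batch, rhs_batch):
        if a == b or b == 1:
            out_batch.append(a)
        elif a == 1:
            out_batch.append(b)
        else:
            raise ValueError(f"cannot broadcast batch dims {a} and {b}")

    # The matrix dimensions of the result are (rows of lhs, cols of rhs).
    return tuple(out_batch) + (m, n)
```

For example, under these assumptions an `lhs` of shape `(8, 1, 4, 3)` and an `rhs` of shape `(5, 3, 2)` produce a result of shape `(8, 5, 4, 2)`.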

This supports rhs inputs of arbitrary rank, but may be less performant than matmul_by_matrix when rhs is rank 2.

Args:

  • lhs (Symbol): The left-hand side of the matmul.
  • rhs (Symbol): The right-hand side of the matmul.

Returns:

A symbolic tensor representing the result of broadcasting the two tensors together according to matmul_broadcast and then performing a matrix multiplication along the last two dimensions of each tensor.
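To make the returned value concrete, here is a small pure-Python reference for the computation on nested lists. It is a sketch for intuition only and operates on concrete values rather than symbolic tensors. It assumes NumPy-style broadcasting for the batch dimensions. The names `batch_matmul_ref`, `matmul2d`, and `rank` are hypothetical.

```python
def rank(x):
    """Depth of a (non-empty, rectangular) nested list."""
    r = 0
    while isinstance(x, list):
        r += 1
        x = x[0]
    return r


def matmul2d(a, b):
    """Plain matrix multiply of two rank-2 nested lists."""
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions differ")
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def batch_matmul_ref(lhs, rhs):
    """Hypothetical reference: broadcast batch dims, then matmul the last two."""
    rl, rr = rank(lhs), rank(rhs)
    if rl < 2 or rr < 2:
        raise ValueError("both operands must have rank >= 2")

    # Prepend size-1 batch dimensions to the lower-rank operand.
    while rl < rr:
        lhs, rl = [lhs], rl + 1
    while rr < rl:
        rhs, rr = [rhs], rr + 1

    # Base case: both operands are plain matrices.
    if rl == 2:
        return matmul2d(lhs, rhs)

    # Broadcast the leading dimension: a size-1 side is reused for every index.
    nl, nr = len(lhs), len(rhs)
    if nl != nr and 1 not in (nl, nr):
        raise ValueError(f"cannot broadcast batch dims {nl} and {nr}")
    n = max(nl, nr)
    return [
        batch_matmul_ref(lhs[i if nl > 1 else 0], rhs[i if nr > 1 else 0])
        for i in range(n)
    ]
```

Here a rank-3 `lhs` of two 2x2 matrices is multiplied by a single rank-2 `rhs` that swaps columns. The `rhs` is broadcast across the batch, so each matrix in `lhs` has its columns swapped.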