band_part

band_part(input: Symbol, num_lower: Symbol, num_upper: Symbol, exclude: Bool = 0) -> Symbol

Masks out everything except a diagonal band of an input matrix.

Copies a tensor, setting every element outside the central diagonal band of each matrix to zero. All axes except the last two are treated as batch axes, and the last two axes define the matrices that are masked.

Given an input of shape [I, J, ..., M, N], the output tensor has the same shape as the input and, when exclude is false, its values are given by

out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n]

with the indicator function:

in_band(m, n) = (num_lower < 0 || (m - n) <= num_lower) &&
                (num_upper < 0 || (n - m) <= num_upper)
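The formula and indicator function above can be modelled in plain Python as a reference sketch of the semantics. This is not the MAX Graph API: `in_band` and `band_part_ref` are hypothetical helper names, the input is assumed to be a non-empty nested list whose innermost two levels are the matrix rows and columns, and num_lower / num_upper are plain ints rather than symbolic tensors.

```python
def in_band(m, n, num_lower, num_upper):
    """Indicator: True if element (m, n) lies inside the diagonal band."""
    return (num_lower < 0 or (m - n) <= num_lower) and (
        num_upper < 0 or (n - m) <= num_upper
    )


def band_part_ref(x, num_lower, num_upper, exclude=False):
    """Hypothetical reference model of band_part on nested lists.

    Every nesting level above the innermost two is treated as a batch
    axis; the innermost two levels form an M x N matrix.
    """
    # If x[0][0] is itself a list, the outermost level is a batch axis:
    # apply the op independently to each slice.
    if isinstance(x[0][0], list):
        return [band_part_ref(b, num_lower, num_upper, exclude) for b in x]
    # Otherwise x is a matrix. Keep an element when its band membership
    # differs from `exclude` (an XOR), and set it to 0 otherwise.
    return [
        [
            v if in_band(m, n, num_lower, num_upper) != exclude else 0
            for n, v in enumerate(row)
        ]
        for m, row in enumerate(x)
    ]


x = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(band_part_ref(x, 1, 0))  # [[1, 0, 0], [4, 5, 0], [0, 8, 9]]

batch = [x, x]  # shape [2, 3, 3]: one leading batch axis
print(band_part_ref(batch, 0, -1))
# [[[1, 2, 3], [0, 5, 6], [0, 0, 9]], [[1, 2, 3], [0, 5, 6], [0, 0, 9]]]
```

The recursion mirrors the "all but the last two axes are batches" rule: each batch slice is masked with the same in_band test, indexed only by the last two coordinates (m, n).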

Args:

  • input (Symbol): The input to mask out.
  • num_lower (Symbol): The number of diagonal bands to include below the central diagonal. If negative (for example -1), the entire lower triangle is included.
  • num_upper (Symbol): The number of diagonal bands to include above the central diagonal. If negative (for example -1), the entire upper triangle is included.
  • exclude (Bool): If true, invert the selection: elements inside the band are set to zero, and elements outside the band are copied from the input.
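The argument combinations below illustrate how num_lower, num_upper, and exclude interact on a single 4 x 4 matrix. This is a sketch of the documented semantics using a hypothetical pure-Python helper, `band_part_2d`, with plain integer arguments; it is not a call into the MAX Graph API and does not handle batch axes.

```python
def band_part_2d(x, num_lower, num_upper, exclude=False):
    """Hypothetical 2-D reference model of band_part (not the MAX API)."""

    def in_band(m, n):
        # Negative bounds mean "no limit" on that side of the diagonal.
        return (num_lower < 0 or m - n <= num_lower) and (
            num_upper < 0 or n - m <= num_upper
        )

    # Keep an element when its band membership differs from `exclude`.
    return [
        [v if in_band(m, n) != exclude else 0 for n, v in enumerate(row)]
        for m, row in enumerate(x)
    ]


x = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]

diagonal = band_part_2d(x, 0, 0)                    # main diagonal only
lower = band_part_2d(x, -1, 0)                      # lower triangle, incl. diagonal
upper = band_part_2d(x, 0, -1)                      # upper triangle, incl. diagonal
tridiagonal = band_part_2d(x, 1, 1)                 # one band on each side
off_diagonal = band_part_2d(x, 0, 0, exclude=True)  # diagonal zeroed, rest kept

# exclude=True swaps the kept and zeroed elements, so the two results
# add back up to the original input element by element.
kept = band_part_2d(x, 1, 1)
dropped = band_part_2d(x, 1, 1, exclude=True)
restored = [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(kept, dropped)]
assert restored == x
```

Because exclude only flips which side of the indicator is kept, the masked and unmasked results always partition the input, which is what the final check relies on.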

Returns:

A symbolic tensor with the same shape as input, in which the masked elements are set to 0 and the remaining elements are copied from the input tensor.