Mojo function
band_part
band_part(input: Symbol, num_lower: Symbol, num_upper: Symbol, exclude: Bool = False) -> Symbol
Masks out everything except a diagonal band of an input matrix.
Copies a tensor, setting everything outside the central diagonal band of each matrix to zero. All but the last two axes are treated as batch axes, and the last two axes define the sub-matrices.
Assume the input has dimensions [I, J, ..., M, N]. The output tensor then has the same shape as the input, and its values are given by
out[i, j, ..., m, n] = in_band(m, n) * input[i, j, ..., m, n].
with the indicator function:
in_band(m, n) = (num_lower < 0 || (m - n) <= num_lower) &&
(num_upper < 0 || (n - m) <= num_upper)
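The indicator function can be checked by hand on a small matrix. The following is a minimal pure-Python sketch of the semantics only, not the Mojo API: `band_part_2d` is a hypothetical helper, and `num_lower` and `num_upper` are plain integers here rather than `Symbol` values.

```python
def in_band(m, n, num_lower, num_upper):
    """Indicator function: True when element (m, n) lies inside the band."""
    lower_ok = num_lower < 0 or (m - n) <= num_lower
    upper_ok = num_upper < 0 or (n - m) <= num_upper
    return lower_ok and upper_ok


def band_part_2d(matrix, num_lower, num_upper):
    """Hypothetical helper: zero every element of a 2-D list outside the band."""
    return [
        [value if in_band(m, n, num_lower, num_upper) else 0
         for n, value in enumerate(row)]
        for m, row in enumerate(matrix)
    ]


matrix = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
]

# Keep the main diagonal plus one diagonal below it, and nothing above.
banded = band_part_2d(matrix, num_lower=1, num_upper=0)
for row in banded:
    print(row)
# [1, 0, 0, 0]
# [5, 6, 0, 0]
# [0, 10, 11, 0]
# [0, 0, 15, 16]
```

Passing a negative value for both bounds leaves the matrix unchanged, since every element then satisfies the indicator.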
Args:
- input (Symbol): The input to mask out.
- num_lower (Symbol): The number of diagonal bands to include below the central diagonal. If -1, include the entire lower triangle.
- num_upper (Symbol): The number of diagonal bands to include above the central diagonal. If -1, include the entire upper triangle.
- exclude (Bool): If true, invert the selection of elements to mask. Elements in the band are set to zero.
Returns:
A symbolic tensor value with the configured selection masked out to 0 values, and the remaining values copied from the input tensor.
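To illustrate how the batch axes and the `exclude` flag interact, here is a pure-Python reference sketch over nested lists. It is an assumption-laden model of the documented behavior, not the Mojo implementation: `band_part_ref` is a hypothetical name, the bounds are plain integers, and `exclude=True` is read as "zero the band, keep everything else", following the argument description.

```python
def band_part_ref(tensor, num_lower, num_upper, exclude=False):
    """Hypothetical reference: mask the last two axes of a nested list."""
    # A list of lists of scalars is a matrix; anything deeper is a batch axis.
    if isinstance(tensor[0][0], list):
        return [band_part_ref(sub, num_lower, num_upper, exclude) for sub in tensor]

    def keep(m, n):
        inside = (num_lower < 0 or m - n <= num_lower) and (
            num_upper < 0 or n - m <= num_upper
        )
        # Assumed reading of exclude: invert the selection, so the band is zeroed.
        return inside != exclude

    return [
        [value if keep(m, n) else 0 for n, value in enumerate(row)]
        for m, row in enumerate(tensor)
    ]


# Shape [2, 3, 3]: a batch of two 3x3 matrices.
batch = [
    [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
    [[1, 1, 1], [1, 1, 1], [1, 1, 1]],
]

# num_lower=-1 keeps the whole lower triangle; num_upper=0 keeps nothing above.
lower_triangle = band_part_ref(batch, -1, 0)

# The same band with exclude=True leaves only the strict upper triangle.
strict_upper = band_part_ref(batch, -1, 0, exclude=True)
```

Under this reading, the two results are complementary: adding `lower_triangle` and `strict_upper` element by element reproduces the input.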