Mojo function

repeat_interleave

repeat_interleave(input: Symbol, repeats: Int, dim: Optional[Int] = None) -> Symbol

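Repeats each element of a tensor along the given dimension, placing the copies of an element next to each other.

Modeled after torch.repeat_interleave, with the constraint that Tensor-valued repeats (a separate count per element) are not yet supported: repeats is a single Int applied to every element.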
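To make the semantics concrete, here is a minimal pure-Python reference model. It assumes a tensor stored as a flat row-major (C-order) list plus a shape tuple; it is not the MAX Graph API, does not build a graph, and the name repeat_interleave_ref is hypothetical. The key idea is that output position j along dim reads input position j // repeats, which keeps the copies of each element adjacent instead of tiling the whole dimension.

```python
from math import prod
from typing import List, Optional, Sequence, Tuple


def repeat_interleave_ref(
    data: Sequence[float],
    shape: Sequence[int],
    repeats: int,
    dim: Optional[int] = None,
) -> Tuple[List[float], Tuple[int, ...]]:
    """Hypothetical reference model returning (flat_output, output_shape).

    `data` holds the tensor's elements in row-major order.
    """
    if repeats < 0:
        raise ValueError("repeats must be non-negative")
    if len(data) != prod(shape):
        raise ValueError("data length does not match shape")

    if dim is None:
        # dim=None: flatten first, then repeat every element in place.
        return [x for x in data for _ in range(repeats)], (len(data) * repeats,)

    # Assumption: this sketch only accepts non-negative, in-range dims.
    if not 0 <= dim < len(shape):
        raise ValueError("dim out of range")

    outer = prod(shape[:dim])      # number of slices before `dim`
    size = shape[dim]              # extent of the repeated dimension
    inner = prod(shape[dim + 1:])  # contiguous elements after `dim`

    out: List[float] = []
    for o in range(outer):
        for j in range(size * repeats):
            src = j // repeats  # each input index is read `repeats` times in a row
            start = (o * size + src) * inner
            out.extend(data[start:start + inner])

    out_shape = tuple(shape[:dim]) + (size * repeats,) + tuple(shape[dim + 1:])
    return out, out_shape
```

Running it on the 2×2 input from the example that follows, with repeats=2, gives the same element order and shapes as the dim=0, dim=1, and dim=None outputs.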

For example, given repeats=2 and the following input:

input = max.tensor.Tensor[DType.float32](
max.tensor.TensorShape(2, 2),
1.0, 2.0,
3.0, 4.0,
)

repeat_interleave with dim=0:

output = max.tensor.Tensor[DType.float32](
max.tensor.TensorShape(4, 2),
1.0, 2.0,
1.0, 2.0,
3.0, 4.0,
3.0, 4.0,
)

repeat_interleave with dim=1:

output = max.tensor.Tensor[DType.float32](
max.tensor.TensorShape(2, 4),
1.0, 1.0, 2.0, 2.0,
3.0, 3.0, 4.0, 4.0,
)

repeat_interleave with dim=None (the default):

output = max.tensor.Tensor[DType.float32](
max.tensor.TensorShape(8),
1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0,
)
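The output shapes in these examples follow a simple rule: with an integer dim, only that dimension grows, by a factor of repeats; with dim=None, the result is 1-D with one entry per repeated element. A minimal pure-Python sketch of that rule is below; repeat_interleave_output_shape is a hypothetical shape-checking helper, not part of MAX, and it assumes dim is a non-negative index.

```python
from math import prod
from typing import Optional, Sequence, Tuple


def repeat_interleave_output_shape(
    shape: Sequence[int], repeats: int, dim: Optional[int] = None
) -> Tuple[int, ...]:
    """Hypothetical helper: the shape repeat_interleave would produce."""
    if dim is None:
        # Flattened input: every element appears `repeats` times.
        return (prod(shape) * repeats,)
    if not 0 <= dim < len(shape):
        raise ValueError("dim out of range")
    out = list(shape)
    out[dim] *= repeats  # only the repeated dimension changes
    return tuple(out)


# Shapes from the 2x2, repeats=2 examples.
print(repeat_interleave_output_shape((2, 2), 2, dim=0))  # (4, 2)
print(repeat_interleave_output_shape((2, 2), 2, dim=1))  # (2, 4)
print(repeat_interleave_output_shape((2, 2), 2))         # (8,)
```

Whichever dim is chosen, the total number of elements grows by exactly a factor of repeats.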

Args:

  • input (Symbol): The input tensor.
  • repeats (Int): The number of repetitions for each element.
  • dim (Optional[Int]): The dimension along which to repeat values. If dim is None (the default), the input is flattened first and the result is a 1-D tensor.

Returns:

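A symbolic tensor in which each element of the input is repeated repeats times consecutively along dim (or along the flattened input when dim is None).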
