
Constrain elements in a PyTorch tensor to be equal


I have a PyTorch tensor and would like to impose equality constraints on its elements while optimizing. An example 2 * 9 tensor is shown below, where the same color indicates elements that should always be equal.

[Example tensor]

Let's build a minimal 1 * 4 example and initialize the first two and the last two elements to be equal, respectively.

import torch
x1 = torch.tensor([1.2, 1.2, -0.3, -0.3], requires_grad=True)
print(x1)
# tensor([ 1.2000,  1.2000, -0.3000, -0.3000], requires_grad=True)

If I perform a simple least-squares step directly on x1, the equality no longer holds.

y = torch.arange(4)
opt_1 = torch.optim.SGD([x1], lr=0.1)
opt_1.zero_grad()
loss = (y - x1).pow(2).sum()
loss.backward()
opt_1.step()
print(x1)
# tensor([0.9600, 1.1600, 0.1600, 0.3600], requires_grad=True)

I tried to express this tensor as a weighted sum of masks.

def weighted_sum(c, masks):
    # Scale each mask by its coefficient, stack them, and sum along dim 0
    return torch.sum(torch.stack([ci * m for ci, m in zip(c, masks)]), dim=0)

c = torch.tensor([1.2, -0.3], requires_grad=True)
masks = torch.tensor([[1, 1, 0, 0], [0, 0, 1, 1]])
x2 = weighted_sum(c, masks)
print(x2)
# tensor([ 1.2000,  1.2000, -0.3000, -0.3000], grad_fn=<SumBackward1>)

In this way, the equality remains after optimization.

opt_c = torch.optim.SGD([c], lr=0.1)
opt_c.zero_grad()
y = torch.arange(4)
x2 = weighted_sum(c, masks)
loss = (y - x2).pow(2).sum()
loss.backward()
opt_c.step()
print(c)
# tensor([0.9200, 0.8200], requires_grad=True)
print(weighted_sum(c, masks))
# tensor([0.9200, 0.9200, 0.8200, 0.8200], grad_fn=<SumBackward1>)

However, the biggest issue with this solution is that I have to maintain a large set of masks when the input dimension is high, which will surely run out of memory. Suppose the input tensor has shape d_0 * d_1 * ... * d_m and there are k equality blocks; then the stacked masks have shape k * d_0 * d_1 * ... * d_m, which is unacceptable.
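To make the scale concrete, here is a rough back-of-the-envelope illustration; the shape and block count below are made up purely for the sake of example.

import math

# Hypothetical numbers only: a float32 mask of shape k * d_0 * ... * d_m
# holds k * prod(d) elements, each taking 4 bytes
k = 1000                       # assumed number of equality blocks
d = (64, 128, 128)             # assumed input shape
elements = k * math.prod(d)
print(elements * 4 / 1024**3)  # ~3.9 GiB just for the masks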


Another solution might be upsampling a low-resolution tensor, like this one (a rough sketch follows the example below). However, it cannot be applied to irregular equality blocks, e.g.,

tensor([[ 1.2000,  1.2000,  1.2000, -3.1000, -3.1000],
        [-0.1000,  2.0000,  2.0000,  2.0000,  2.0000]])
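For reference, a minimal sketch of what I mean by the upsampling idea, on the regular 1 * 4 example; nearest-neighbour interpolation replicates each low-resolution value over a fixed-size block, so it only covers equally sized, grid-aligned blocks:

import torch
import torch.nn.functional as F

c = torch.tensor([1.2, -0.3], requires_grad=True)
# Nearest-neighbour upsampling from 2 values to 4: each value is repeated twice
x = F.interpolate(c.view(1, 1, 2), size=4, mode='nearest').view(4)
print(x)
# values: [1.2, 1.2, -0.3, -0.3]; gradients flow back to c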

So... is there a smarter way of implementing such equality constraints in a PyTorch tensor?


Solution

  • If you want them to always be equal, why not just remove the duplicated first and last values from x and y? The extra values can be derived from the model output after training, since they're expected to equal their neighbors anyway; there's no need to learn two copies of the same value.

    If you instead want the model to only approximately learn that they're the same, you could add some_weight * (torch.abs(x[0] - x[1]) + torch.abs(x[-1] - x[-2])) to your loss function. The loss would then push these elements towards being equal.

    Or, instead of masks, if you have counts for each value, maybe you're looking for something like this?

    def convert(tensor, counts):
        # Repeat each value `count` times and concatenate into one flat tensor
        return torch.cat([v.repeat(count) for (v, count) in zip(tensor, counts)])

    convert(torch.arange(4), [3, 2, 1, 3])
    # tensor([0, 0, 0, 1, 1, 2, 3, 3, 3])
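
    As a quick sanity check that this plays nicely with autograd, here is a minimal usage sketch reusing convert from above; the target y, the counts, the learning rate, and the iteration count are made up for illustration. Gradients flow through repeat and cat back to the compact parameter vector c, so elements within each block stay equal by construction, even for the irregular 2 * 5 example from the question.

    c = torch.tensor([1.2, -3.1, -0.1, 2.0], requires_grad=True)  # one entry per equality block
    counts = [3, 2, 1, 4]                                         # irregular block sizes, summing to 10
    y = torch.arange(10, dtype=torch.float32).reshape(2, 5)       # dummy target

    opt = torch.optim.SGD([c], lr=0.1)
    for _ in range(100):
        opt.zero_grad()
        x = convert(c, counts).reshape(2, 5)   # full 2 * 5 tensor with tied elements
        loss = (y - x).pow(2).sum()
        loss.backward()
        opt.step()

    print(convert(c, counts).reshape(2, 5))
    # each block converges to the mean of its targets, and the ties still hold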