Batched tensor creation inside torch.vmap


Inside a function, I want to create a tensor with torch.zeros whose shape depends on the shape of the function's input. Then I want to vectorize the function with torch.vmap.

Something like this:

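import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial):
    # the companion matrix has size (deg+1) x (deg+1)
    deg = polynomial.shape[-1] - 2
    # torch.zeros creates a regular (unbatched) tensor, even inside vmap
    companion = torch.zeros((deg+1, deg+1))
    companion[1:,:-1] = torch.eye(deg)
    # writing batched values into the unbatched tensor in place is what fails under vmap
    companion[:,-1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

polycompanion_vmap = torch.vmap(polycompanion)
print(polycompanion_vmap(poly_batched))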

The problem is that the batched version will not work, because companion won’t be a BatchedTensor, unlike polynomial, which was the input.

There is a workaround:

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial, companion):
    # companion is passed in already batched, so the in-place writes are allowed
    deg = companion.shape[-1] - 1
    companion[1:,:-1] = torch.eye(deg)
    companion[:,-1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

polycompanion_vmap = torch.vmap(polycompanion)

# pre-allocate one zero matrix per batch element, outside of vmap
batch, n = poly_batched.shape[0], poly_batched.shape[-1] - 1
print(polycompanion_vmap(poly_batched, torch.zeros(batch, n, n)))

Output:

tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])

But this is ugly.

Is there a solution for this? Will this be supported in the future?

Note: If you use torch.zeros_like on an input to the function, it works and creates a BatchedTensor, but this doesn’t help me here because the tensor I need has a different shape than the input.

Thanks in advance for the help!


Solution

  • The problem is that the batched version will not work, because companion won’t be a BatchedTensor, unlike polynomial, which was the input.

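    Turns out you can make companion a BatchedTensor using torch.Tensor.new_zeros. So instead of companion = torch.zeros((deg+1, deg+1)), writing companion = polynomial.new_zeros((deg+1, deg+1)) will work. Note that new_zeros inherits the dtype and device of polynomial, so pass a floating-point input (or set dtype explicitly); otherwise the fractional coefficients are cast to the integer dtype.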

    There is an issue about this on the PyTorch GitHub: issue

    This behavior might change in the future, as a developer suggested.
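Putting the suggestion above together, here is a minimal sketch of the question's function using new_zeros. It assumes a PyTorch version that provides torch.vmap and a floating-point input (the question used integers; the float literals here are my change, since new_zeros copies the input's dtype). It is meant as an illustration of the idea, not a definitive implementation.

```python
import torch

def polycompanion(polynomial):
    deg = polynomial.shape[-1] - 2
    # new_zeros is called on the (batched) input, so under vmap the result
    # is batched too; it also inherits polynomial's dtype and device
    companion = polynomial.new_zeros((deg + 1, deg + 1))
    companion[1:, :-1] = torch.eye(deg)
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

# assumption: float coefficients, so the fractional entries are kept
poly_batched = torch.tensor([[1., 2., 3., 4.], [1., 2., 3., 4.]])

polycompanion_vmap = torch.vmap(polycompanion)
out = polycompanion_vmap(poly_batched)
print(out.shape)  # one 3x3 companion matrix per batch element
print(out)
```

The function signature stays the same as in the original, unbatched version, so the zero tensor no longer has to be allocated outside of vmap and passed in.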