I want to create a tensor with `torch.zeros` based on the shape of an input to the function. Then I want to vectorize the function with `torch.vmap`.
Something like this:

```python
import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial):
    deg = polynomial.shape[-1] - 2
    # Plain torch.zeros: inside vmap this is an ordinary, unbatched tensor
    companion = torch.zeros((deg+1, deg+1))
    companion[1:,:-1] = torch.eye(deg)
    # Fails under vmap: writes batched values into the unbatched tensor
    companion[:,-1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

polycompanion_vmap = torch.vmap(polycompanion)
print(polycompanion_vmap(poly_batched))
```
The problem is that the batched version will not work: `companion` won't be a `BatchedTensor`, unlike `polynomial`, which was the input, so writing values computed from `polynomial` into `companion` in place raises an error.
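There is a workaround: pass a preallocated, batched buffer in as a second argument.

```python
import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial, companion):
    # companion arrives already batched because it is a vmap input
    deg = companion.shape[-1] - 1
    companion[1:,:-1] = torch.eye(deg)
    companion[:,-1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

polycompanion_vmap = torch.vmap(polycompanion)
# Buffer of shape (batch, n, n), allocated outside the function
buffer = torch.zeros(poly_batched.shape[0], poly_batched.shape[-1]-1, poly_batched.shape[-1]-1)
print(polycompanion_vmap(poly_batched, buffer))
```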
Output:

```
tensor([[[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]],

        [[ 0.0000,  0.0000, -0.2500],
         [ 1.0000,  0.0000, -0.5000],
         [ 0.0000,  1.0000, -0.7500]]])
```
But this is ugly: the caller has to allocate the output buffer with the right batch shape.
Is there a cleaner solution for this? Will this be supported in the future?
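Until there is a cleaner answer, the preallocation in the workaround can at least be moved out of the call site. This is a minimal sketch: the wrapper name `polycompanion_batched` is hypothetical, and it assumes exactly one leading batch dimension.

```python
import torch

def polycompanion(polynomial, companion):
    # Workaround from above: companion is a preallocated, batched buffer
    deg = companion.shape[-1] - 1
    companion[1:, :-1] = torch.eye(deg)
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

def polycompanion_batched(poly_batched):
    # Hypothetical wrapper: allocate the (batch, n, n) buffer here so that
    # callers only pass the polynomials (assumes one leading batch dim)
    n = poly_batched.shape[-1] - 1
    buffer = torch.zeros(poly_batched.shape[0], n, n)
    return torch.vmap(polycompanion)(poly_batched, buffer)

result = polycompanion_batched(torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]))
print(result)
```

The buffer is still mutated in place; the wrapper only hides that detail from callers.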
Note: `torch.zeros_like` on an input to the function does work and creates a `BatchedTensor`, but it doesn't help me here, because `companion` has a different shape than `polynomial`.
Thanks in advance for the help!
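Turns out you can make `companion` a `BatchedTensor` using `torch.Tensor.new_zeros`. So instead of `companion = torch.zeros((deg+1, deg+1))`, `companion = polynomial.new_zeros((deg+1, deg+1))` will work. Note that `new_zeros` inherits the dtype of `polynomial`; when the input is an integer tensor, as `poly_batched` is here, pass `dtype=torch.float32`, otherwise the fractional values in the last column are cast to integers (e.g. -0.25 becomes 0).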
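Putting the `new_zeros` fix into the original function gives the following sketch. The explicit `dtype=torch.float32` is an assumption on top of the suggestion above, added because `poly_batched` holds integers.

```python
import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion(polynomial):
    deg = polynomial.shape[-1] - 2
    # new_zeros on the (batched) input yields a batched tensor under vmap;
    # float dtype avoids truncating the last column of an integer input
    companion = polynomial.new_zeros((deg + 1, deg + 1), dtype=torch.float32)
    companion[1:, :-1] = torch.eye(deg)
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

result = torch.vmap(polycompanion)(poly_batched)
print(result)
```

No buffer has to be passed in, so the call site is the same as in the original, non-working version.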
There is an issue about this on the PyTorch GitHub: issue
According to a developer there, this behavior might change in the future.
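Independently of `new_zeros`, the in-place writes that trigger the error can be avoided altogether by building the matrix out of place, so no freshly created tensor ever needs to be batched. This is a sketch, not something proposed in the thread: the name `polycompanion_outofplace` is hypothetical, and it assumes the same coefficient order as above with a nonzero last coefficient.

```python
import torch

poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

def polycompanion_outofplace(polynomial):
    # Hypothetical variant: no tensor is mutated, so nothing has to be
    # batched before it is written to
    n = polynomial.shape[-1] - 1  # size of the companion matrix
    # Last column -c_i / c_n; batched under vmap since it comes from the input
    last_col = -polynomial[:-1] / polynomial[-1]
    # Ones on the first subdiagonal; its last column is already zero
    shift = torch.diag(torch.ones(n - 1, dtype=last_col.dtype), -1)
    # One-hot row vector that selects the last column
    e_last = torch.eye(n, dtype=last_col.dtype)[-1]
    # (n, 1) * (n,) broadcasts to (n, n) with last_col in the final column
    return shift + last_col.unsqueeze(-1) * e_last

result = torch.vmap(polycompanion_outofplace)(poly_batched)
print(result)
```

Because every step is an ordinary out-of-place operation, vmap can batch it directly; mixing unbatched constants (`shift`, `e_last`) with the batched `last_col` is handled by broadcasting.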