Batched tensor creation inside torch.vmap
Matthew Barrera
I want to create a tensor with torch.zeros based on the shape of an input to the function. Then I want to vectorize the function with torch.vmap.
Something like this:

    import torch

    poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

    def polycompanion(polynomial):
        deg = polynomial.shape[-1] - 2
        # Created inside the function, so this is a regular tensor, not a BatchedTensor
        companion = torch.zeros((deg + 1, deg + 1))
        companion[1:, :-1] = torch.eye(deg)
        companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
        return companion

    polycompanion_vmap = torch.vmap(polycompanion)
    print(polycompanion_vmap(poly_batched))

The problem is that the batched version does not work: companion is created inside the function as a regular tensor, so it is not a BatchedTensor, unlike polynomial, which is the input.
There is a workaround:

    import torch

    poly_batched = torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]])

    def polycompanion(polynomial, companion):
        # companion is passed in pre-allocated, so vmap treats it as batched
        deg = companion.shape[-1] - 1
        companion[1:, :-1] = torch.eye(deg)
        companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
        return companion

    polycompanion_vmap = torch.vmap(polycompanion)
    print(polycompanion_vmap(
        poly_batched,
        torch.zeros(poly_batched.shape[0], poly_batched.shape[-1] - 1, poly_batched.shape[-1] - 1),
    ))

Output:

    tensor([[[ 0.0000,  0.0000, -0.2500],
             [ 1.0000,  0.0000, -0.5000],
             [ 0.0000,  1.0000, -0.7500]],

            [[ 0.0000,  0.0000, -0.2500],
             [ 1.0000,  0.0000, -0.5000],
             [ 0.0000,  1.0000, -0.7500]]])

But this is ugly.
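The best I can do with this workaround is hide the allocation in a small wrapper so the call site stays tidy. This is only a sketch: `polycompanion_batched` is a helper name I made up, not something PyTorch provides, and it assumes a 2-D input of shape (batch, n).

```python
import torch

def polycompanion(polynomial, companion):
    # Same per-sample function as in the workaround above
    deg = companion.shape[-1] - 1
    companion[1:, :-1] = torch.eye(deg)
    companion[:, -1] = -1. * polynomial[:-1] / polynomial[-1]
    return companion

def polycompanion_batched(poly_batched):
    # Hypothetical wrapper: allocate the batched buffer once, outside vmap
    batch, n = poly_batched.shape[0], poly_batched.shape[-1] - 1
    buffer = torch.zeros(batch, n, n)
    return torch.vmap(polycompanion)(poly_batched, buffer)

result = polycompanion_batched(torch.tensor([[1, 2, 3, 4], [1, 2, 3, 4]]))
print(result.shape)
```

The buffer is still allocated outside the vmapped function, so this only hides the problem rather than solving it.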
Is there a solution for this? Will this be supported in the future?
Note: calling torch.zeros_like on an input to the function does work and creates a BatchedTensor, but that doesn't help me here, because companion has a different shape from polynomial.
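To illustrate what I mean by the note, here is a minimal sketch of the torch.zeros_like case. `shift_right` is a toy function made up for this example. Its output has the same shape as its input, which is exactly why this trick does not carry over to the companion matrix.

```python
import torch

def shift_right(x):
    # zeros_like on a batched input gives a batched output,
    # so the in-place write below is accepted under vmap
    out = torch.zeros_like(x)
    out[1:] = x[:-1]
    return out

batch = torch.tensor([[1., 2., 3.], [4., 5., 6.]])
shifted = torch.vmap(shift_right)(batch)
print(shifted)
```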
Thanks in advance for the help!