
Tensor power and multiplication in PyTorch


I have a matrix A and a tensor b of shape (1, 3), so effectively a vector of 3 coefficients.

I want to compute

C = b1 * A + b2 * A^2 + b3 * A^3, where A^n denotes the n-th power of A.

At the end, C should have the same shape as A. How can I do this efficiently?
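Written compactly, with both possible meanings of A^n spelled out (the question does not say which one is intended; the Horner factorization on the last line is an editorial addition, not part of the question):

```latex
C = \sum_{k=1}^{3} b_k A^{k} = b_1 A + b_2 A^{2} + b_3 A^{3}

\text{element-wise power: } (A^{k})_{ij} = (A_{ij})^{k} \quad (\text{any shape of } A)

\text{matrix power: } A^{k} = \underbrace{A A \cdots A}_{k \text{ factors}}, \qquad
C = A\,\bigl(b_1 I + A\,(b_2 I + b_3 A)\bigr) \quad (A \text{ square})
```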


Solution

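  • Treating ^n as the element-wise power, broadcast A against a vector of exponents, multiply each power by its coefficient, and sum over the new last axis:

    import torch

    A = torch.ones(1, 2, 3)
    b_vals = torch.tensor([2, 3, 4])  # b1, b2, b3; if b has shape (1, 3), use b.reshape(-1)
    powers = torch.tensor([1, 2, 3])  # exponents 1, 2, 3

    # A[..., None] has shape (1, 2, 3, 1); broadcasting with powers gives
    # shape (1, 2, 3, 3) holding A, A**2, A**3 element-wise. Scale by the
    # coefficients and sum the last axis so C has the same shape as A.
    C = (A[..., None] ** powers * b_vals).sum(-1)


    Output (every entry of A is 1, so each entry of C is 2 + 3 + 4 = 9):

    tensor([[[9., 9., 9.],
             [9., 9., 9.]]])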
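If A^n is instead meant as the matrix power (A @ A @ ...), A must be square and the element-wise broadcasting trick does not apply. Below is a minimal sketch under that assumption; the helper name `matrix_poly` is hypothetical. It evaluates the polynomial with Horner's scheme, so k coefficients cost k matrix multiplications and no intermediate powers are stored. It accepts b in shape (1, 3) as in the question, or any shape with k elements.

```python
import torch


def matrix_poly(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: C = b[0]*A + b[1]*A^2 + ... + b[k-1]*A^k,
    where A^n is the matrix power (repeated @), NOT the element-wise power.

    A: square matrix of shape (n, n), or a batch of them (..., n, n).
    b: k coefficients in any shape, e.g. (1, 3) or (3,).
    """
    # Flatten b and match A's device and dtype.
    coeffs = b.reshape(-1).to(device=A.device, dtype=A.dtype)
    eye = torch.eye(A.shape[-1], dtype=A.dtype, device=A.device)
    # Horner's scheme: b1*A + b2*A^2 + b3*A^3 = A @ (b1*I + A @ (b2*I + b3*A))
    acc = coeffs[-1] * eye
    for c in coeffs[:-1].flip(0):
        acc = c * eye + A @ acc
    return A @ acc


A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[2.0, 3.0, 4.0]])  # shape (1, 3), as in the question
C = matrix_poly(A, b)
print(C)  # same shape as A: (2, 2)
```

For this A the result equals 2 * A + 3 * (A @ A) + 4 * (A @ A @ A). With only three terms, writing that expression out directly is just as reasonable; the loop mainly pays off when the number of coefficients varies.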