python numpy mean einsum

Speedup with einsum and average over multiple arrays


Given two arrays A and B with 4 indices each, A[i,j,n_i,m] and B[i,j,n_i,m], I need to take their tensor product over the first two indices while keeping the remaining two. For this I use einsum. I also need to average the result over the last index, for which I just apply .mean() along that axis.

C=np.einsum('ij...,kl...->ikjl...',A[:,:,:,:],B[:,:,:,:]).reshape(1024,1024,len(t_list),Mmax).mean(axis=3)   

While this works, I would like to know if there is a more efficient way of doing this, because right now it is quite slow. I am afraid my bottleneck is the size of the arrays: the original A and B have shape (32,32,1000,10000), and the result C has shape (1024,1024,1000), which is quite costly.

However, using einsum and then averaging using .mean(axis=3) is still much faster than looping over Mmax, creating the arrays and then taking the average.

I should comment that what I am really after is the reduced matrix, or the partial traces of this matrix, i.e.

from qutip import *
dim=2*np.ones(2*5)
for n_i,n in enumerate(t_list):
    C_s=C[:,:,n_i]
    Qobj(C_s,dims=[dim,dim]).ptrace([0,1])

If possible, I would like to do this operation without having to allocate C directly, which I think is what makes the whole process slow.


Solution

  • As mentioned in the comments, the full slices on the input arrays (A[:,:,:,:] and B[:,:,:,:]) are unnecessary; you can pass A and B directly. In addition, as you suggested, it is possible to calculate the partial trace of the outer product directly using np.einsum, without allocating C. It won't help code readability though...

    Next time, maybe try to distil an external function like this down to its NumPy equivalent yourself; the question is easier to understand with that context. Looking at the implementation of ptrace and the arguments you passed to it, it appears you will end up calling _ptrace_dense (in the QuTiP source), and specifically the else branch of that function:

    rhomat = np.trace(Q.full()
        .reshape(rd + rd)
        .transpose(qtrace + [nd + q for q in qtrace] +
                    sel + [nd + q for q in sel])
        .reshape([np.prod(dtrace, dtype=np.int32),
                np.prod(dtrace, dtype=np.int32),
                np.prod(dkeep, dtype=np.int32),
                np.prod(dkeep, dtype=np.int32)]))
    

    As you can see, this does some complex reshaping/transposition then takes the np.trace of the result. You can take advantage of the fact that the trace of an outer product is actually an inner product, both of which are exactly what np.einsum was made for. This removes a large number of multiplications and requires much less memory.
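
As a small illustration of that identity (a sketch with made-up vectors a and b, not your data): the trace of an outer product equals the inner product, and np.einsum can compute it without ever forming the outer product.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random(6)  # illustrative vectors, not from the question
b = rng.random(6)

# Explicit route: build the full 6x6 outer product, then sum its diagonal.
trace_of_outer = np.trace(np.einsum('i,j->ij', a, b))

# Direct route: repeating the index contracts it, so no 6x6 array is formed.
inner = np.einsum('i,i->', a, b)

assert np.isclose(trace_of_outer, inner)
```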

    You can also sum the final dimension within your np.einsum and divide the output by the length of the final dimension to get rid of your .mean(axis=3).
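
A minimal sketch of that replacement, using small stand-in shapes (the sizes here are assumptions chosen to keep the check cheap): leaving the last index out of the einsum output makes einsum sum over it, and dividing by its length gives the mean.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.random((3, 3, 4, 5))   # small stand-in; the last axis plays the role of Mmax
B = rng.random((3, 3, 4, 5))
Mmax = A.shape[-1]

# Two steps: keep the last axis q in the output, then average over it.
two_step = np.einsum('ijtq,kltq->ikjltq', A, B).mean(axis=-1)

# One step: omit q from the output so einsum sums over it, then divide.
one_step = np.einsum('ijtq,kltq->ikjlt', A, B) / Mmax

assert np.allclose(two_step, one_step)
```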

    In order to do it all in one np.einsum step, you'll have to first reshape one of your input arrays. Ultimately it boils down to:

    A = A.reshape([4, 8, 4, 8, len(t_list), Mmax])
    
    X = np.einsum('ikok...q,ll...q->io...', A, B).reshape((4, 4, len(t_list))) / Mmax
    

    X now holds the stacked results of your for loop, i.e. all the ptrace results on the final axis.
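
If you want to convince yourself before running it on the real data, here is a rough check on tiny stand-ins for len(t_list) and Mmax. It assumes that a plain NumPy partial trace (reshape to 4x256x4x256 and trace the 256-sized axes) is an acceptable substitute for Qobj.ptrace([0,1]); I have not compared it against QuTiP itself here.

```python
import numpy as np

rng = np.random.default_rng(2)
T, M = 2, 2                      # tiny stand-ins for len(t_list) and Mmax
A = rng.random((32, 32, T, M))
B = rng.random((32, 32, T, M))

# Reference: build C as in the question, then trace out everything except
# the leading 4-dimensional factor of each 1024-dimensional index.
C = np.einsum('ij...,kl...->ikjl...', A, B).reshape(1024, 1024, T, M).mean(axis=3)
reference = np.stack(
    [np.trace(C[:, :, t].reshape(4, 256, 4, 256), axis1=1, axis2=3) for t in range(T)],
    axis=-1,
)

# One-step version; C is never allocated.
A6 = A.reshape(4, 8, 4, 8, T, M)
X = np.einsum('ikok...q,ll...q->io...', A6, B).reshape((4, 4, T)) / M

assert np.allclose(X, reference)
```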

    This transformation isn't particularly intuitive, at least to me. However, the steps to get here are fairly easy, if long-winded.

    I'll work through what I did in the specific case: you have two input arrays with shape (32, 32) and want the partial trace that keeps the first two subsets/qubits/whatever. Broadcasting this operation across your other dimensions should be obvious (see NumPy's broadcasting documentation if it isn't).

    Summarising what you're trying to do, it boils down to:

    C = np.einsum('ij,kl->ikjl',A,B)
    # Reshaping C to (1024, 1024) first makes no difference, because `ptrace` reshapes it internally anyway
    X = C.reshape([2] * 20)  # Reshape it into a 2x2x...x2 array (20 axes)
    Y = X.transpose(list(range(2, 10)) + list(range(12, 20)) + [0, 1, 10, 11])  # Move the first two axes of the first and second original axes to the end
    Z = Y.reshape([256, 256, 4, 4])
    output = np.trace(Z)
    

    This can be simplified significantly using the axis1 and axis2 arguments of np.trace (not sure why the QuTiP library doesn't do this):

    C = np.einsum('ij,kl->ikjl',A,B)
    X = C.reshape([4, 256, 4, 256])
    output = np.trace(X, axis1=1, axis2=3)
    

    Replacing np.trace:

    C = np.einsum('ij,kl->ikjl', A, B)
    X = C.reshape([4, 256, 4, 256])
    output = np.einsum('ijkj->ik', X)
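
The three forms so far can be checked against each other on random 32x32 inputs. This is only a sketch; the explicit axis order spells out the transposition used in the QuTiP-style version for this particular 10-qubit case.

```python
import numpy as np

rng = np.random.default_rng(3)
A = rng.random((32, 32))
B = rng.random((32, 32))
C = np.einsum('ij,kl->ikjl', A, B)          # shape (32, 32, 32, 32)

# Form 1: reshape/transpose/reshape, then trace over the first two axes.
order = list(range(2, 10)) + list(range(12, 20)) + [0, 1, 10, 11]
form1 = np.trace(C.reshape([2] * 20).transpose(order).reshape(256, 256, 4, 4))

# Form 2: a single reshape, tracing the two 256-sized axes directly.
X = C.reshape(4, 256, 4, 256)
form2 = np.trace(X, axis1=1, axis2=3)

# Form 3: the same trace written as an einsum with a repeated index.
form3 = np.einsum('ijkj->ik', X)

assert np.allclose(form1, form2) and np.allclose(form2, form3)
```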
    

    Two slightly non-intuitive things that may help explain the next few steps:

    A = np.random.rand(4, 4)
    B = np.random.rand(4, 4)
    X = np.einsum('ij,kl->ijkl', A, B)  # Note this is ijkl NOT ikjl
    Y = np.einsum('i,j->ij', A.flatten(), B.flatten())
    assert np.allclose(X.flatten(), Y.flatten())
    

    and (though not as important):

    A = np.random.rand(4, 4)
    B = np.random.rand(4, 4)
    X = np.einsum('ij,kl->ikjl', A, B)
    Y = np.einsum('ij,kl->ijkl', A, B).transpose((0, 2, 1, 3)) # Or .swapaxes(1, 2)
    assert np.allclose(X.flatten(), Y.flatten())
    

    This first point means that:

    A = np.random.rand(32, 32)
    B = np.random.rand(32, 32)
    C = np.einsum('ij,kl->ikjl', A, B)
    A = A.reshape((4, 8, 4, 8))
    B = B.reshape((4, 8, 4, 8))
    D = np.einsum('ij kl,mn op->ij mn kl op', A, B) # Spaces have no meaning, just help map the dimension groups to previous names,
    # So i=ij, j=kl, k=mn and l=op
    assert np.allclose(C.flatten(), D.flatten())
    

    This is helpful because in order to merge the two np.einsums, you need to be able to split A and B (well, maybe not B, more on that later) into 32x32 for the first np.einsum and 4x256 for the second, as previously shown. As such, the shape of A and B has to be (4, 8, 4, 8), since 4*8=32, satisfying the first requirement, and 8*4*8=256, satisfying the second. That changes the previous simplification of the problem to:

    A = np.random.rand(4, 8, 4, 8)
    B = np.random.rand(4, 8, 4, 8)
    C = np.einsum('ij kl,mn op->ij mn kl op', A, B)
    output = np.einsum('i jmn k jmn->ik', C).reshape((4, 4))
    

    Now all you have to do is merge the two np.einsums. This is easy, since ijmnklop from the first output maps directly to ijmnkjmn of the second input (that is, l becomes j, o becomes m and p becomes n). Therefore:

    A = np.random.rand(4, 8, 4, 8)
    B = np.random.rand(4, 8, 4, 8)
    output = np.einsum('i j kj,m n mn->ik', A, B).reshape((4, 4))
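
To see that the merge really is equivalent, here is a quick comparison of the two-step and merged versions on random data (a sketch; subscripts are written without spaces here, which is the same thing to np.einsum).

```python
import numpy as np

rng = np.random.default_rng(4)
A = rng.random((4, 8, 4, 8))
B = rng.random((4, 8, 4, 8))

# Two steps: outer product (8 axes), then trace the matching axis groups.
C = np.einsum('ijkl,mnop->ijmnklop', A, B)
two_step = np.einsum('ijmnkjmn->ik', C)

# Merged: substitute l->j, o->m, p->n directly in the input subscripts.
merged = np.einsum('ijkj,mnmn->ik', A, B)

assert np.allclose(two_step, merged)
```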
    

    It should be obvious here that B actually doesn't need 4 dimensions, only two. So this finally comes to:

    A = np.random.rand(4, 8, 4, 8)
    B = np.random.rand(32, 32)
    output = np.einsum('i j kj,l l->ik', A, B).reshape((4, 4))
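
Put differently, B only enters as a scalar. A short sketch of that observation: because the 'll' term shares no index with A, the result factors into the partial trace of A times the ordinary trace of B.

```python
import numpy as np

rng = np.random.default_rng(5)
A = rng.random((4, 8, 4, 8))
B = rng.random((32, 32))

combined = np.einsum('ijkj,ll->ik', A, B)

# 'll' shares no index with A, so it contributes only a scalar: trace(B).
factored = np.einsum('ijkj->ik', A) * np.trace(B)

assert np.allclose(combined, factored)
```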
    

    Obviously instead of using np.random like me, reshape your data to the correct shape.

    This isn't a particularly satisfying explanation as the output code seems completely unrelated to the input code. However, I hope the steps help you understand how you can work through it and generalise it to other np.einsum cases.

    Regarding your specific problem of generalising to other subsets, you can reshape A to 2x2x...x2 and select arbitrary indices from there. You don't have to modify B, given the previous (unexpected) simplification. For example, if you wanted 0 and 3 instead of 0 and 1:

    A = A.reshape((2,)*10)
    B = B.reshape((32, 32))
    output = np.einsum('i jk l m o jk p m,z z->i l o p', A, B).reshape((4, 4))
    

    Notice how all dimensions you want to "keep" are named with unique letters and repeated in the output, whereas all dimensions you want to trace are repeated in the input and don't appear in the output. B is simply traced, regardless of which dimensions are chosen.
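
Here is a sketch that checks the "keep 0 and 3" subscripts against a brute-force partial trace of np.kron(A, B). The helper ptrace_reference is hypothetical (written just for this check, mirroring the reshape/transpose/trace recipe from earlier) and is not part of NumPy or QuTiP.

```python
import numpy as np

def ptrace_reference(rho, keep, n_qubits=10):
    """Hypothetical helper: partial trace of a 2**n x 2**n matrix, keeping `keep`."""
    traced = [q for q in range(n_qubits) if q not in keep]
    order = traced + [n_qubits + q for q in traced] + keep + [n_qubits + q for q in keep]
    d_traced, d_keep = 2 ** len(traced), 2 ** len(keep)
    reshaped = rho.reshape([2] * (2 * n_qubits)).transpose(order)
    return np.trace(reshaped.reshape(d_traced, d_traced, d_keep, d_keep))

rng = np.random.default_rng(6)
A = rng.random((32, 32))
B = rng.random((32, 32))

# Keep qubits 0 and 3: their row letters (i, l) and column letters (o, p) are unique.
# Qubits 1, 2 and 4 reuse j, k and m in rows and columns, so they are summed over.
output = np.einsum('ijklmojkpm,zz->ilop', A.reshape((2,) * 10), B).reshape((4, 4))

assert np.allclose(output, ptrace_reference(np.kron(A, B), [0, 3]))
```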

    Finally, you can write a function to generate the subscripts automatically:

    from typing import List

    def get_subscripts(keep_indices: List[int], length: int, start_at='b', trace_char='a') -> str:
        start = ord(start_at)
        in_subs = list(range(start, start + length))*2
        out_subs = [start + k for k in keep_indices]
        for i, k in enumerate(keep_indices):
            in_subs[length + k] = start + length + i
            out_subs.append(start + length + i)
        return ''.join(chr(x) for x in in_subs) + ',' + trace_char*2 + '->' + ''.join(chr(x) for x in out_subs)
    
    subsets = [0, 1, 3]
    A = A.reshape((2,)*10)
    B = B.reshape((32, 32))
    output = np.einsum(get_subscripts(subsets, 5), A, B).reshape((2**len(subsets),)*2)
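
To make the generated string less mysterious, this sketch repeats the function and prints what it produces for subsets [0, 1, 3]. Letters b to f label the five row axes of A, the kept column axes get fresh letters g, h and i, and aa traces B.

```python
from typing import List

def get_subscripts(keep_indices: List[int], length: int, start_at='b', trace_char='a') -> str:
    start = ord(start_at)
    # Row and column axes start with identical letters, i.e. everything is traced.
    in_subs = list(range(start, start + length)) * 2
    out_subs = [start + k for k in keep_indices]
    for i, k in enumerate(keep_indices):
        # Give each kept column axis its own letter and add it to the output.
        in_subs[length + k] = start + length + i
        out_subs.append(start + length + i)
    return (''.join(chr(x) for x in in_subs) + ',' + trace_char * 2
            + '->' + ''.join(chr(x) for x in out_subs))

subs = get_subscripts([0, 1, 3], 5)
print(subs)  # bcdefghdif,aa->bceghi
```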
    

    Fun problem!

    Edit: If you want to generalise it to any subsets, you will need to reshape both arrays and modify the subscripts function slightly:

    def get_subscripts(keep_indices: List[int], length: int, start_at='a') -> str:
        start = ord(start_at)
        assert all(k < 2*length for k in keep_indices)
        in_subs_1 = list(range(start, start + length))*2
        in_subs_2 = list(range(start + length, start + 2*length))*2
        out_subs = [start + k for k in keep_indices]
        for i, k in enumerate(keep_indices):
            next_sub = start + 2*length + i
            if k < length:
                in_subs_1[length + k] = next_sub
            else:
                in_subs_2[k] = next_sub
            out_subs.append(next_sub)
        return ''.join(chr(x) for x in in_subs_1) + ',' + ''.join(chr(x) for x in in_subs_2) + '->' + ''.join(chr(x) for x in out_subs)
    
    A = A.reshape((2,)*10)
    B = B.reshape((2,)*10)
    
    output = np.einsum(get_subscripts(subsets, 5), A, B).reshape((2**len(subsets),)*2)
    

    Basically just apply the same logic used for the first array to the second. The only other change is to move some of the offsets.
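
A sketch that exercises the general version on a subset spanning both arrays, qubits 1, 6 and 8, and compares it with a brute-force partial trace of np.kron(A, B). As before, ptrace_reference is a hypothetical helper written only for this check, and keep_indices is assumed to be in ascending order.

```python
from typing import List
import numpy as np

def get_subscripts(keep_indices: List[int], length: int, start_at='a') -> str:
    start = ord(start_at)
    assert all(k < 2 * length for k in keep_indices)
    in_subs_1 = list(range(start, start + length)) * 2
    in_subs_2 = list(range(start + length, start + 2 * length)) * 2
    out_subs = [start + k for k in keep_indices]
    for i, k in enumerate(keep_indices):
        next_sub = start + 2 * length + i
        if k < length:
            in_subs_1[length + k] = next_sub   # kept axis lives in the first array
        else:
            in_subs_2[k] = next_sub            # kept axis lives in the second array
        out_subs.append(next_sub)
    return (''.join(chr(x) for x in in_subs_1) + ','
            + ''.join(chr(x) for x in in_subs_2) + '->'
            + ''.join(chr(x) for x in out_subs))

def ptrace_reference(rho, keep, n_qubits=10):
    """Hypothetical helper: brute-force partial trace, keeping `keep`."""
    traced = [q for q in range(n_qubits) if q not in keep]
    order = traced + [n_qubits + q for q in traced] + keep + [n_qubits + q for q in keep]
    d_traced, d_keep = 2 ** len(traced), 2 ** len(keep)
    reshaped = rho.reshape([2] * (2 * n_qubits)).transpose(order)
    return np.trace(reshaped.reshape(d_traced, d_traced, d_keep, d_keep))

rng = np.random.default_rng(7)
A = rng.random((32, 32))
B = rng.random((32, 32))
subsets = [1, 6, 8]

subs = get_subscripts(subsets, 5)
output = np.einsum(subs, A.reshape((2,) * 10), B.reshape((2,) * 10))
output = output.reshape((2 ** len(subsets),) * 2)

assert np.allclose(output, ptrace_reference(np.kron(A, B), subsets))
```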

    Edit 2: After writing this answer, I wrote a Python package to do all this automatically, but forgot to link to it here. This may be relevant to someone.