
Matrix Multiplication with Object Arrays in Python


I am wondering how matrix multiplication can be supported in numpy with arrays of dtype=object. I have homomorphically encrypted numbers that are encapsulated in a class Ciphertext, for which I have overridden the basic math operators like __add__, __mul__, etc.

I have created a numpy array where each entry is an instance of my class Ciphertext, and numpy broadcasts addition and multiplication operations just fine.

    encryptedInput = builder.encrypt_as_array(np.array([6,7])) # type(encryptedInput) is <class 'numpy.ndarray'>
    encryptedOutput = encryptedInput + encryptedInput
    builder.decrypt(encryptedOutput)                           # Result: np.array([12,14])
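
To reproduce this setup without an encryption library, here is a minimal sketch assuming a hypothetical MockCiphertext class that simply wraps a plain integer. The helpers encrypt_as_array and decrypt are hypothetical stand-ins for the builder methods above, not the real API. Elementwise + and * work because numpy calls each object's own __add__ and __mul__ for dtype=object arrays.

```python
import numpy as np


class MockCiphertext:
    """Hypothetical stand-in for the question's Ciphertext class.

    It only wraps a plain integer; a real homomorphic scheme would hold
    encrypted data and do the arithmetic on that instead.
    """

    def __init__(self, value):
        self.value = value

    @staticmethod
    def _raw(other):
        # Accept both MockCiphertext operands and plain Python numbers.
        return other.value if isinstance(other, MockCiphertext) else other

    def __add__(self, other):
        return MockCiphertext(self.value + self._raw(other))

    __radd__ = __add__  # addition is commutative, so reuse __add__

    def __mul__(self, other):
        return MockCiphertext(self.value * self._raw(other))

    __rmul__ = __mul__  # multiplication is commutative as well

    def __repr__(self):
        return f"MockCiphertext({self.value})"


def encrypt_as_array(plain):
    """Hypothetical helper: wrap every entry in a MockCiphertext (dtype=object)."""
    plain = np.asarray(plain)
    out = np.empty(plain.shape, dtype=object)
    for idx, x in np.ndenumerate(plain):
        out[idx] = MockCiphertext(int(x))
    return out


def decrypt(cipher):
    """Hypothetical helper: unwrap an object array back to plain integers."""
    cipher = np.asarray(cipher, dtype=object)
    return np.array([c.value for c in cipher.ravel()]).reshape(cipher.shape)


enc = encrypt_as_array(np.array([6, 7]))
print(decrypt(enc + enc))  # [12 14]
print(decrypt(enc * enc))  # [36 49]
```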

However, numpy won't let me do matrix multiplication:

    out = encryptedInput @ encryptedInput  # TypeError: Object arrays are not currently supported

I don't quite understand why this happens, considering that addition and multiplication work. I guess it has something to do with numpy not being able to know the shape of the object, since it could be a list or something fancy.

Naive Solution: I could write my own class that extends ndarray and override the __matmul__ operation. However, I would probably lose out on performance, and this approach also entails implementing broadcasting etc., so I would basically be reinventing the wheel for something that should already work.

Question: How can I use the standard matrix multiplication provided by numpy on arrays with dtype=object, where the objects behave exactly like numbers?
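
For context: the TypeError comes from older NumPy releases whose matmul had no implementation for object arrays. NumPy 1.17 added object-array support to matmul, so on a recent NumPy the @ operator may simply work. On older versions, np.dot already handles object arrays by calling each element's __mul__ and __add__. The sketch below assumes a hypothetical helper named object_matmul and uses fractions.Fraction as a stand-in for any number-like class such as Ciphertext. It tries @ first and falls back to np.dot only for 1-D and 2-D operands, where the two agree.

```python
from fractions import Fraction

import numpy as np


def object_matmul(a, b):
    """Hypothetical helper: matrix product that also works for dtype=object.

    Tries the @ operator first; if this NumPy version rejects object
    arrays in matmul (older releases raise TypeError), fall back to
    np.dot, which calls the objects' own __mul__ and __add__.
    """
    try:
        return a @ b
    except TypeError:
        # np.dot matches @ for 1-D and 2-D operands, but NOT for stacks
        # of matrices (ndim > 2), which broadcast differently under @.
        if a.ndim > 2 or b.ndim > 2:
            raise
        return np.dot(a, b)


# Fraction stands in for any number-like class such as Ciphertext.
m = np.array([[Fraction(6), Fraction(7)],
              [Fraction(8), Fraction(9)]], dtype=object)
print(object_matmul(m, m))
```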

Thank you in advance!


Solution

  • For whatever reason matmul doesn't work, but the tensordot function works as expected:

    encryptedInput = builder.encrypt_as_array(np.array([[6, 7], [8, 9]]))
    out = np.tensordot(encryptedInput, encryptedInput, axes=([1], [0]))
    builder.decrypt(out)
    # Correct result: [[ 92. 105.]
    #                  [120. 137.]]
    

    Now it's just a hassle to adjust the axes. I still wonder whether this is actually faster than a naive implementation with for loops.
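
The axes argument of np.tensordot names which axis of the first operand is contracted against which axis of the second. For the ordinary matrix product of 1-D or 2-D arrays, that is the last axis of a against the first axis of b, so a small wrapper can hide the bookkeeping. The sketch below assumes hypothetical helpers tensordot_matmul and loop_matmul (the latter is the naive for-loop version, for comparison) and again uses fractions.Fraction in place of Ciphertext.

```python
from fractions import Fraction

import numpy as np


def tensordot_matmul(a, b):
    """Hypothetical wrapper: matrix product via np.tensordot.

    Contracts the last axis of `a` with the first axis of `b`, which is
    what @ does for 1-D and 2-D operands (not for ndim > 2 stacks).
    """
    return np.tensordot(a, b, axes=([a.ndim - 1], [0]))


def loop_matmul(a, b):
    """Hypothetical naive triple loop for 2-D arrays, for comparison only."""
    n, k = a.shape
    k2, m = b.shape
    if k != k2 or k == 0:
        raise ValueError(f"cannot multiply shapes {a.shape} and {b.shape}")
    out = np.empty((n, m), dtype=object)
    for i in range(n):
        for j in range(m):
            # Start from the first product instead of 0, so no
            # "encrypted zero" object is needed.
            acc = a[i, 0] * b[0, j]
            for p in range(1, k):
                acc = acc + a[i, p] * b[p, j]
            out[i, j] = acc
    return out


m = np.array([[Fraction(6), Fraction(7)],
              [Fraction(8), Fraction(9)]], dtype=object)
print(tensordot_matmul(m, m).tolist() == loop_matmul(m, m).tolist())  # True
```

On the speed question: both versions call the elements' __mul__ and __add__ the same number of times, so with expensive homomorphic operations the running time is probably dominated by the encrypted arithmetic either way. np.tensordot mainly removes the Python-level loop overhead; timing both on real Ciphertext objects would settle it.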