
matrix multiplication for complex numbers in PyTorch


I am trying to multiply two complex matrices in PyTorch, and it seems that torch.matmul has not yet been added to the PyTorch library for complex numbers.

Do you have any recommendation or is there another method to multiply complex matrices in PyTorch?


Solution

  • When this answer was first written (before PyTorch 1.7.0), torch.matmul did not support complex tensors such as ComplexFloatTensor, but the product can be computed compactly from real-valued matrix multiplications:

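    import torch

    def matmul_complex(t1, t2):
        # (A + iB)(C + iD) = (AC - BD) + i(AD + BC), using only real-valued matmuls
        real = t1.real @ t2.real - t1.imag @ t2.imag
        imag = t1.real @ t2.imag + t1.imag @ t2.real
        # dim=-1 puts the (real, imag) pair last, as torch.view_as_complex expects
        return torch.view_as_complex(torch.stack((real, imag), dim=-1))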
    

    Avoid Python for loops where possible, since they make the implementation much slower. Vectorization means expressing the computation with built-in tensor operations, as in the code above. For example, for two random complex matrices of size 1000 × 1000, your code takes roughly 6.1 s on CPU, while the vectorized version takes only 101 ms (about 60 times faster).
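
    To illustrate the point about loops, here is a minimal sketch that checks the vectorized function against a loop-based baseline on small matrices. The function matmul_complex_loops is a hypothetical stand-in written for this comparison, not the code from the question, and the matrix sizes are arbitrary. It assumes a PyTorch version that supports element-wise arithmetic on complex tensors. Timings depend on hardware, so none are quoted here.

```python
import time
import torch

def matmul_complex(t1, t2):
    # Vectorized: (A + iB)(C + iD) = (AC - BD) + i(AD + BC)
    real = t1.real @ t2.real - t1.imag @ t2.imag
    imag = t1.real @ t2.imag + t1.imag @ t2.real
    return torch.view_as_complex(torch.stack((real, imag), dim=-1))

def matmul_complex_loops(t1, t2):
    # Hypothetical loop-based baseline: one dot product per output entry.
    n, m = t1.shape[0], t2.shape[1]
    out = torch.zeros(n, m, dtype=t1.dtype)
    for i in range(n):
        for j in range(m):
            out[i, j] = (t1[i, :] * t2[:, j]).sum()
    return out

torch.manual_seed(0)
t1 = torch.randn(50, 40, dtype=torch.cfloat)
t2 = torch.randn(40, 30, dtype=torch.cfloat)

start = time.perf_counter()
fast = matmul_complex(t1, t2)
t_fast = time.perf_counter() - start

start = time.perf_counter()
slow = matmul_complex_loops(t1, t2)
t_slow = time.perf_counter() - start

# Compare as real tensors of shape (..., 2) so the check uses real arithmetic only.
same = torch.allclose(torch.view_as_real(fast), torch.view_as_real(slow), atol=1e-3)
print("results match:", same)
print(f"vectorized: {t_fast:.4f} s, loops: {t_slow:.4f} s")
```

    Both functions should agree up to floating-point rounding; the gap in running time grows quickly with the matrix size.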

    Update:

    Since PyTorch 1.7.0 (as @EduardoReis mentioned), complex matrices can be multiplied in the same way as real-valued matrices:

    t1 @ t2 (for t1, t2 complex matrices).
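
    As a small worked example of the native operator, the sketch below multiplies a 1 × 2 complex matrix by a 2 × 1 complex matrix. The values are chosen here only for illustration, and PyTorch 1.7.0 or later is assumed.

```python
import torch

t1 = torch.tensor([[1 + 1j, 2 + 0j]], dtype=torch.cfloat)    # shape (1, 2)
t2 = torch.tensor([[1 - 1j], [0 + 1j]], dtype=torch.cfloat)  # shape (2, 1)

# The @ operator and torch.matmul are equivalent here.
product = t1 @ t2
also_product = torch.matmul(t1, t2)

# By hand: (1 + 1j)(1 - 1j) + 2 * 1j = 2 + 2j
print(product)
```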