Search code examples
c++eigeneigen3

Custom scalar type in Eigen


I'm currently trying to set up a custom scalar type for use with the Eigen3 library (atm it is a simple wrapper around double). I have followed https://eigen.tuxfamily.org/dox/TopicCustomizing_CustomScalar.html to the best of my knowledge and the basic things work fine.

I need to solve eigenvalue-problems for the matrices with my custom type though and this is exactly where things start falling apart. My compiler spits the following error message at me:

/Eigen3/Eigen/src/Householder/Householder.h:131:18: error: no viable overloaded '-='
    this->row(0) -= tau * tmp;
    ~~~~~~~~~~~~ ^  ~~~~~~~~~
/Eigen3/Eigen/src/Eigenvalues/HessenbergDecomposition.h:314:10: note: in instantiation of function template specialization 'Eigen::MatrixBase<Eigen::Block<Eigen::Matrix<MyDouble, 2,
      2, 0, 2, 2>, -1, -1, false> >::applyHouseholderOnTheLeft<Eigen::VectorBlock<Eigen::Block<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2>, 2, 1, true>, -1> >' requested here
        .applyHouseholderOnTheLeft(matA.col(i).tail(remainingSize-1), h, &temp.coeffRef(0));
         ^
/Eigen3/Eigen/src/Eigenvalues/HessenbergDecomposition.h:161:7: note: in instantiation of member function 'Eigen::HessenbergDecomposition<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2>
      >::_compute' requested here
      _compute(m_matrix, m_hCoeffs, m_temp);
      ^
/Eigen3/Eigen/src/Eigenvalues/./RealSchur.h:271:10: note: in instantiation of function template specialization 'Eigen::HessenbergDecomposition<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2>
      >::compute<Eigen::CwiseBinaryOp<Eigen::internal::scalar_quotient_op<MyDouble, MyDouble>, const Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2>, const Eigen::CwiseNullaryOp<Eigen::internal::scalar_constant_op<MyDouble>, const
      Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2> > > >' requested here
  m_hess.compute(matrix.derived()/scale);
         ^
/Eigen3/Eigen/src/Eigenvalues/EigenSolver.h:389:15: note: in instantiation of function template specialization 'Eigen::RealSchur<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2>
      >::compute<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2> >' requested here
  m_realSchur.compute(matrix.derived(), computeEigenvectors);
              ^
/Eigen3/Eigen/src/Eigenvalues/EigenSolver.h:156:7: note: in instantiation of function template specialization 'Eigen::EigenSolver<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2>
      >::compute<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2> >' requested here
      compute(matrix.derived(), computeEigenvectors);
      ^
/home/adam/Documents/Git/spin_phonon_coupling/src/test.cpp:205:52: note: in instantiation of function template specialization 'Eigen::EigenSolver<Eigen::Matrix<MyDouble, 2, 2, 0, 2, 2> >::EigenSolver<Eigen::Matrix<MyDouble, 2, 2, 0, 2,
      2> >' requested here
        Eigen::EigenSolver<Eigen::Matrix<MyDouble, 2, 2>> solver(test);
                                                          ^
/Eigen3/Eigen/src/Core/DenseBase.h:298:14: note: candidate template ignored: could not match 'EigenBase<type-parameter-0-0>' against 'MyDouble'
    Derived& operator-=(const EigenBase<OtherDerived> &other);
             ^
/Eigen3/Eigen/src/Core/MatrixBase.h:161:14: note: candidate template ignored: could not match 'MatrixBase<type-parameter-0-0>' against 'MyDouble'
    Derived& operator-=(const MatrixBase<OtherDerived>& other);
             ^
/Eigen3/Eigen/src/Core/MatrixBase.h:495:46: note: candidate template ignored: could not match 'ArrayBase<type-parameter-0-0>' against 'MyDouble'
    template<typename OtherDerived> Derived& operator-=(const ArrayBase<OtherDerived>& )

So the base-problem seems to be a missing overload for the -= operator. The problem is that I don't have a clue why. As far as I can see, I have defined the overloaded operator for my type, but if I understand the error correctly, it seems as if the internal Eigen-type is missing that overload...

Does anybody have an idea, what might be the problem here?


Here's the Code producing above error:

#include <iostream>
#include <cmath>

#include <Eigen/Dense>

class MyDouble {
    public:
        double value;
        MyDouble() : value() {};
        MyDouble(double val) : value(val) {};

        template<typename T>
            MyDouble operator+(T other) const {
                return value + other;
            }
        template<>
            MyDouble operator+(MyDouble other) const {
                return MyDouble(value + other.value);
            }

        template<typename T>
            MyDouble operator-(T other) const {
                return value - other;
            }
        template<>
            MyDouble operator-(MyDouble other) const {
                return MyDouble(value - other.value);
            }

        template<typename T>
            MyDouble operator*(T other) const {
                return value * other;
            }
        template<>
            MyDouble operator*(MyDouble other) const {
                return MyDouble(value * other.value);
            }

        template<typename T>
            MyDouble operator/(T other) const {
                return value / other;
            }
        template<>
            MyDouble operator/(MyDouble other) const {
                return MyDouble(value / other.value);
            }

        template<typename T>
            MyDouble& operator+=(T other) {
                value += other;
                return *this;
            }
        template<>
            MyDouble& operator+=(MyDouble other) {
                value += other.value;
                return *this;
            }

        template<typename T>
            MyDouble& operator-=(const T &other) {
                value -= other;
                return *this;
            }
        template<>
            MyDouble& operator-=(const MyDouble &other) {
                value -= other.value;
                return *this;
            }

        template<typename T>
            MyDouble& operator*=(T other) {
                value *= other.value;
                return *this;
            }
        template<>
            MyDouble& operator*=(MyDouble other) {
                value *= other.value;
                return *this;
            }

        template<typename T>
            MyDouble& operator/=(T other) {
                value /= other;
                return *this;
            }
        template<>
            MyDouble& operator/=(MyDouble other) {
                value /= other.value;
                return *this;
            }

        MyDouble operator-() const {
            return -value;
        }

        template<typename T>
            bool operator<(T other) const {
                return value < other;
            }
        template<>
            bool operator<(MyDouble other) const {
                return value < other.value;
            }

        template<typename T>
            bool operator>(T other) const {
                return value > other;
            }
        template<>
            bool operator>(MyDouble other) const {
                return value > other.value;
            }

        template<typename T>
            bool operator<=(T other) const {
                return value <= other;
            }
        template<>
            bool operator <=(MyDouble other) const {
                return value <= other.value;
            }

        template<typename T>
            bool operator>=(T other) const {
                return value >= other;
            }
        template<>
            bool operator>=(MyDouble other) const {
                return value >= other.value;
            }

        template<typename T>
            bool operator==(T other) const {
                return value == other;
            }
        template<>
            bool operator==(MyDouble other) const {
                return value == other.value;
            }

        template<typename T>
            bool operator!=(T other) const {
                return value != other;
            }
        template<>
            bool operator!=(MyDouble other) const {
                return value != other.value;
            }

        friend std::ostream& operator<<(std::ostream& out, const MyDouble& val) {
            out << val.value << " m";
            return out;
        }

        operator double() {
            return value;
        }
};

MyDouble sqrt(MyDouble val) {
    return std::sqrt(val.value);
}
MyDouble abs(MyDouble val) {
    return std::abs(val.value);
}
MyDouble abs2(MyDouble val) {
    return val * val;
}

namespace Eigen {
    template<> struct NumTraits<MyDouble>
        : NumTraits<double> // permits to get the epsilon, dummy_precision, lowest, highest functions
        {
            typedef MyDouble Real;
            typedef MyDouble NonInteger;
            typedef MyDouble Nested;
            enum {
                IsComplex = 0,
                IsInteger = 0,
                IsSigned = 1,
                RequireInitialization = 1,
                ReadCost = 1,
                AddCost = 3,
                MulCost = 3
            };
        };
    template<typename BinaryOp>
    struct ScalarBinaryOpTraits<MyDouble,double,BinaryOp> { typedef MyDouble ReturnType;  };

    template<typename BinaryOp>
    struct ScalarBinaryOpTraits<double,MyDouble,BinaryOp> { typedef MyDouble ReturnType;  };
}

int main() {
    Eigen::Matrix<MyDouble, 2, 2> test;
    test << 1, 2, 3, 4;

    Eigen::Matrix<double, 2, 2> test2;
    test2 << 5, 6, 7, 8;

    MyDouble a = 3;
    a -= 2;
    a -= MyDouble(3);

    Eigen::EigenSolver<Eigen::Matrix<MyDouble, 2, 2>> solver(test);
    std::cout << test.trace() << std::endl;
    std::cout << solver.eigenvalues() << std::endl;
}

Solution

  • Marc Glisse pointed out that my operator-overloads appeared to be kinda weird, so I rewrote them and the initial issue was gone. I don't have a clue as of why exactly this is the case though.

    Then I only had the problem that Eigen's isfinite function was not defined for my custom type, so I went ahead and added an implementation of it as well (though I'm not sure if that is actually a solid one).

    Anyways my code compiles now. Here's the modified version of my code:

    #include <iostream>
    #include <cmath>
    #include <complex>
    
    #include <Eigen/Dense>
    
    class MyDouble {
        public:
            double value;
            MyDouble() : value() {};
            MyDouble(double val) : value(val) {};
    
            template<typename T>
                MyDouble& operator+=(T rhs) {
                    value = static_cast<double>(value + rhs);
                    return *this;
                }
    
            template<typename T>
                MyDouble& operator-=(const T &rhs) {
                    value = static_cast<double>(value - rhs);
                    return *this;
                }
    
            template<typename T>
                MyDouble& operator*=(T rhs) {
                    value = static_cast<double>(value * rhs);
                    return *this;
                }
    
            template<typename T>
                MyDouble& operator/=(T rhs) {
                    value = static_cast<double>(value / rhs);
                    return *this;
                }
    
            MyDouble operator-() const {
                return -value;
            }
    
            friend std::ostream& operator<<(std::ostream& out, const MyDouble& val) {
                out << val.value << " m";
                return out;
            }
    
            explicit operator double() {
                return value;
            }
    };
    
    #define OVERLOAD_OPERATOR(op,ret) ret operator op(const MyDouble &lhs, const MyDouble &rhs) { \
            return lhs.value op rhs.value; \
        }
    
    OVERLOAD_OPERATOR(+, MyDouble)
    OVERLOAD_OPERATOR(-, MyDouble)
    OVERLOAD_OPERATOR(*, MyDouble)
    OVERLOAD_OPERATOR(/, MyDouble)
    
    OVERLOAD_OPERATOR(>, bool)
    OVERLOAD_OPERATOR(<, bool)
    OVERLOAD_OPERATOR(>=, bool)
    OVERLOAD_OPERATOR(<=, bool)
    OVERLOAD_OPERATOR(==, bool)
    OVERLOAD_OPERATOR(!=, bool)
    
    
    MyDouble sqrt(MyDouble val) {
        return std::sqrt(val.value);
    }
    MyDouble abs(MyDouble val) {
        return std::abs(val.value);
    }
    MyDouble abs2(MyDouble val) {
        return val * val;
    }
    bool isfinite(const MyDouble &) { return true; }
    
    namespace Eigen {
        template<> struct NumTraits<MyDouble>
            : NumTraits<double> // permits to get the epsilon, dummy_precision, lowest, highest functions
            {
                typedef MyDouble Real;
                typedef MyDouble NonInteger;
                typedef MyDouble Nested;
                enum {
                    IsComplex = 0,
                    IsInteger = 0,
                    IsSigned = 1,
                    RequireInitialization = 0,
                    ReadCost = 1,
                    AddCost = 3,
                    MulCost = 3
                };
            };
    
        template<typename BinaryOp>
        struct ScalarBinaryOpTraits<MyDouble,double,BinaryOp> { typedef MyDouble ReturnType;  };
    
        template<typename BinaryOp>
        struct ScalarBinaryOpTraits<double,MyDouble,BinaryOp> { typedef MyDouble ReturnType;  };
    }
    
    int main() {
        Eigen::Matrix<MyDouble, 2, 2> test;
        test << 1, 2, 3, 4;
    
        Eigen::Matrix<double, 2, 2> reference;
        reference << 1, 2, 3, 4;
    
        MyDouble a = 3;
        a += 2;
        a = 2 + a;
        a = a + 2;
        a -= 2;
        a -= MyDouble(3);
    
        a = a / a;
    
        std::complex<MyDouble> complexTest(3,4);
        complexTest *= 2;
    
        Eigen::EigenSolver<Eigen::Matrix<MyDouble, 2, 2>> solver(test);
        Eigen::EigenSolver<Eigen::Matrix<double, 2, 2>> refSolver(reference);
        std::cout << "MyDouble:" << std::endl;
        std::cout << test.trace() << std::endl;
        std::cout << solver.eigenvalues() << std::endl;
        std::cout << "\nRefernce:" << std::endl;
        std::cout << reference.trace() << std::endl;
        std::cout << refSolver.eigenvalues() << std::endl;
    }