c++ automatic-differentiation

C++ reverse automatic differentiation with graph


I'm trying to implement reverse-mode automatic differentiation in C++.

The idea I came up with is that each variable that results from an operation on one or two other variables stores, in a vector, the local gradient with respect to each operand together with a pointer to that operand.

This is the code:

#include <iostream>
#include <utility>
#include <vector>

class Var {
    private:
        double value;
        char character;
        std::vector<std::pair<double, const Var*> > children;

    public:
        Var(const double& _value=0, const char& _character='_') : value(_value), character(_character) {};
        void set_character(const char& character){ this->character = character; }

        // computes the derivative of the current object with respect to 'var'
        double gradient(Var* var) const{
            if(this==var){
                return 1.0;
            }

            double sum=0.0;
            for(auto& pair : children){
                // std::cout << "(" << this->character << " -> " <<  pair.second->character << ", " << this << " -> " << pair.second << ", weight=" << pair.first << ")" << std::endl;
                sum += pair.first*pair.second->gradient(var);
            }
            return sum;
        }

        friend Var operator+(const Var& l, const Var& r){
            Var result(l.value+r.value);
            result.children.push_back(std::make_pair(1.0, &l));
            result.children.push_back(std::make_pair(1.0, &r));
            return result;
        }

        friend Var operator*(const Var& l, const Var& r){
            Var result(l.value*r.value);
            result.children.push_back(std::make_pair(r.value, &l));
            result.children.push_back(std::make_pair(l.value, &r));
            return result;
        }

        friend std::ostream& operator<<(std::ostream& os, const Var& var){
            os << var.value;
            return os;
        }
};

I tried to run the code like this:

int main(int argc, char const *argv[]) {
    Var x(5,'x'), y(6,'y'), z(7,'z');

    Var k = z + x*y;
    k.set_character('k');

    std::cout << "k = " << k << std::endl;
    std::cout << "∂k/∂x = " << k.gradient(&x) << std::endl;
    std::cout << "∂k/∂y = " << k.gradient(&y) << std::endl;
    std::cout << "∂k/∂z = " << k.gradient(&z) << std::endl;

    return 0;
}

The computational graph that should be built is the following:

       x(5)   y(6)              z(7)
         \     /                 /
 ∂w/∂x=y  \   /  ∂w/∂y=x        /
           \ /                 /
          w=x*y               /
             \               /  ∂k/∂z=1
              \             /
      ∂k/∂w=1  \           /
                \_________/
                     |
                   k=w+z

Then, to calculate ∂k/∂x for instance, I multiply the gradients along each path from k down to x and sum the results over all paths. This is done recursively by double gradient(Var* var) const. So I have ∂k/∂x = ∂k/∂w * ∂w/∂x + ∂k/∂z * ∂z/∂x.
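
Worked out for the values above, as a sketch of the result the program is expected to print (w denotes the intermediate x*y, as in the graph, and ∂z/∂x = 0 because z does not depend on x):

```latex
\frac{\partial k}{\partial x}
  = \frac{\partial k}{\partial w}\,\frac{\partial w}{\partial x}
  + \frac{\partial k}{\partial z}\,\frac{\partial z}{\partial x}
  = 1 \cdot y + 1 \cdot 0
  = 6
```

By the same reasoning, ∂k/∂y = x = 5 and ∂k/∂z = 1.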

The problem

If there is an intermediate calculation, such as x*y here, something goes wrong. With the std::cout line uncommented, this is the output:

k = 37
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂x = 0
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂y = 5
(k -> z, 0x7ffeb3345740 -> 0x7ffeb3345710, weight=1)
(k -> _, 0x7ffeb3345740 -> 0x7ffeb3345770, weight=1)
(_ -> x, 0x7ffeb3345770 -> 0x7ffeb33456b0, weight=0)
(_ -> y, 0x7ffeb3345770 -> 0x7ffeb33456e0, weight=5)
∂k/∂z = 1

It prints which variable is connected to which one, then their addresses, and the weight of the connection (which should be the gradient).

The problem is weight=0 between x and the intermediate variable which holds the result of x*y (and which I denoted as w in my graph). I have no idea why this one is zero and not the other weight connected to y.

Another thing I noticed is that if you switch the lines in operator* like so:

result.children.push_back(std::make_pair(l.value, &r));
result.children.push_back(std::make_pair(r.value, &l));

Then it is the y connection whose weight becomes zero instead.

Thanks in advance for any help.


Solution

  • The line:

    Var k = z + x*y;
    

    Calls operator*, which returns a Var temporary, which is then used for the r argument to operator+, wherein a pair stores the address of the temporary. After the line completes, k's children include a pointer to the place the temporary had been, but the temporary no longer exists. Dereferencing that pointer in gradient is undefined behaviour, which is why the output looks arbitrary.


    This does not protect against repeating the mistake, but you can get the intended behaviour by giving the intermediate result a name instead of relying on an unnamed temporary...

    Var xy = x * y;
    xy.set_character('*');
    Var k = z + xy;
    k.set_character('k');
    

    ...with which your program produces:

    k = 37
    ∂k/∂x = 6
    ∂k/∂y = 5
    ∂k/∂z = 1
    

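    A better fix may be to capture the children by value, so that a result no longer depends on the lifetime of its operands. Note that gradient identifies var by comparing addresses, and a copied child has a different address from the original, so that comparison would have to change as well.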


    As a general tip for catching such mistakes... when your program seems to be doing something inexplicable (and/or crashing), try running it under a memory error detector such as valgrind. For your code, the report starts off with:

    ==22137== Invalid read of size 8
    ==22137==    at 0x1090EA: Var::gradient(Var*) const (in /home/median/so/deriv)
    ==22137==    by 0x109109: Var::gradient(Var*) const (in /home/median/so/deriv)
    ==22137==    by 0x108E12: main (in /home/median/so/deriv)
    ==22137==  Address 0x5b82cd0 is 0 bytes inside a block of size 32 free'd
    ==22137==    at 0x4C3123B: operator delete(void*) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
    ==22137==    by 0x109FC1: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109CDD: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::deallocate(std::allocator<std::pair<double, Var const*> >&, std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109963: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_deallocate(std::pair<double, Var const*>*, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x1097BC: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~_Vector_base() (in /home/median/so/deriv)
    ==22137==    by 0x1095EA: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::~vector() (in /home/median/so/deriv)
    ==22137==    by 0x109161: Var::~Var() (in /home/median/so/deriv)
    ==22137==    by 0x108D95: main (in /home/median/so/deriv)
    ==22137==  Block was alloc'd at
    ==22137==    at 0x4C3017F: operator new(unsigned long) (in /usr/lib/valgrind/vgpreload_memcheck-amd64-linux.so)
    ==22137==    by 0x10A153: __gnu_cxx::new_allocator<std::pair<double, Var const*> >::allocate(unsigned long, void const*) (in /home/median/so/deriv)
    ==22137==    by 0x10A060: std::allocator_traits<std::allocator<std::pair<double, Var const*> > >::allocate(std::allocator<std::pair<double, Var const*> >&, unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109F03: std::_Vector_base<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_allocate(unsigned long) (in /home/median/so/deriv)
    ==22137==    by 0x109A8D: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::_M_realloc_insert<std::pair<double, Var const*> >(__gnu_cxx::__normal_iterator<std::pair<double, Var const*>*, std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > > >, std::pair<double, Var const*>&&) (in /home/median/so/deriv)
    ==22137==    by 0x1098CF: void std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::emplace_back<std::pair<double, Var const*> >(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
    ==22137==    by 0x10973F: std::vector<std::pair<double, Var const*>, std::allocator<std::pair<double, Var const*> > >::push_back(std::pair<double, Var const*>&&) (in /home/median/so/deriv)
    ==22137==    by 0x109520: operator*(Var const&, Var const&) (in /home/median/so/deriv)
    ==22137==    by 0x108D6F: main (in /home/median/so/deriv)
    

    Another way to catch it can be to add logging in a destructor so you know when the object addresses mentioned in your logging are no longer valid.