
How do I write a generic gradient_descent algorithm in C++?


I want to write a generic version of the gradient descent algorithm in C++ that passes the following GoogleTest tests:

...

#include <cmath>

TEST(HW6Test, TEST1) {
    auto min1 = q1::gradient_descent(0.01, 0.1, cos);
    EXPECT_NEAR(min1, 3.14, 0.1);

    auto min2 = q1::gradient_descent(0.01, 0.01, cos);
    EXPECT_NEAR(min2, 3.14, 0.01);
}

TEST(HW6Test, TEST2) {
    auto min = q1::gradient_descent(0.01, 0.01, [](double a){return sin(a)+cos(a);});
    EXPECT_NEAR(min, -2.36, 0.01);
}

TEST(HW6Test, TEST3) {
    struct Func
    {
        double operator()(double a) {return cos(a);}
    };
    auto min = q1::gradient_descent(0.01, 0.01, Func{});
    EXPECT_NEAR(min, 3.14, 0.01);
}

TEST(HW6Test, TEST4) {
    struct Func
    {
        double operator()(double a) {return sin(a);}
    };
    auto min = q1::gradient_descent<double, Func>(0.0, 0.01);
    EXPECT_NEAR(min, -1.57, 0.01);
}

As you can see, the math function can be a function pointer, a lambda, or a functor. I would like some suggestions on how to solve this problem.

I have tried:

namespace q1 
{
    template <typename T, typename Func>
    const T& gradient_descent(const T& init_value, const T& step, Func func) {

    }
};

But it fails to compile for these calls:

auto min1 = q1::gradient_descent(0.01, 0.1, cos);
auto min = q1::gradient_descent<double, Func>(0.0, 0.01);

Solution

  • It looks like you want a pair of overloads like this:

    template<typename T, typename Func>
    T gradient_descent(const T& init_value, const T& step, Func func = {}) {
        ...;
    }
    
    template<typename T>
    T gradient_descent(const T& init_value, const T& step, T func(T)) {
        // Call the other overload
        return gradient_descent<T, T(T)>(init_value, step, func);
    }
    

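    So when you call gradient_descent(0.0, 0.0, std::cos), the first overload doesn't apply: Func can't be deduced from an overload set such as std::cos, so template argument deduction fails for it. The second overload does apply: its parameter T func(T) is adjusted to the function pointer type T (*)(T), T = double is deduced from the first two arguments, and that target type selects the double(double) overload of std::cos to pass as a function pointer.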
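To see this deduction behavior in isolation, here is a small sketch. It uses a hypothetical two-member overload set `square` in place of `std::cos` (so it never takes the address of a standard library function) and a hypothetical `which_overload` pair with the same parameter shapes as the two `gradient_descent` overloads; each returns a tag saying which overload was selected.

```cpp
#include <cassert>

// Hypothetical overload set with two members, standing in for std::cos.
double square(double x) { return x * x; }
float square(float x) { return x * x; }

// Same parameter shapes as the generic gradient_descent overload.
template <typename T, typename Func>
int which_overload(const T&, Func) { return 1; }

// Same parameter shape as the function-pointer overload;
// the parameter f is adjusted to T (*)(T).
template <typename T>
int which_overload(const T&, T f(T)) {
    (void)f;  // only the overload choice matters here
    return 2;
}

void check_overload_choice() {
    // Overload set: Func cannot be deduced, so only the T f(T) overload is viable.
    assert(which_overload(1.0, square) == 2);
    // Lambda: a closure type is not T (*)(T), so only the Func overload is viable.
    assert(which_overload(1.0, [](double x) { return x; }) == 1);
    // Real function pointer: both are viable, partial ordering picks T f(T).
    assert(which_overload(1.0, +[](double x) { return x; }) == 2);
}
```

The last case shows that passing an actual function pointer makes both overloads viable; partial ordering then prefers the more specialized T f(T) one, which is harmless for gradient_descent because that overload just forwards to the generic one.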

    The first overload also has a default argument = {}, which makes q1::gradient_descent<double, Func>(0.0, 0.01); work by default-constructing the functor.

    Note that since C++20, std::cos and std::sin are not designated addressable functions: a program that forms a pointer to them has unspecified (possibly ill-formed) behavior. To pass them portably, wrap them in a lambda:

    q1::gradient_descent(0.01, 0.1, [](double x) { return cos(x); });
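Putting the pieces together, here is a minimal sketch of a complete q1 header intended to pass the four tests. The question does not say what the second argument means, so this sketch assumes it is the learning rate and estimates the derivative with a central difference, because only f itself is passed in. The difference width `h`, the gradient `tolerance`, and `max_iterations` are illustrative values chosen for this sketch, not part of the answer above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>

namespace q1 {

// Assumptions: `step` is the learning rate, and f'(x) is estimated
// numerically with a central difference.
template <typename T, typename Func>
T gradient_descent(const T& init_value, const T& step, Func func = {}) {
    assert(step > T{0});  // a non-positive learning rate would not descend
    const T h = static_cast<T>(1e-5);          // finite-difference width (assumed)
    const T tolerance = static_cast<T>(1e-8);  // "flat enough" gradient (assumed)
    const std::size_t max_iterations = 1000000;

    T x = init_value;
    for (std::size_t i = 0; i < max_iterations; ++i) {
        // func is a non-const by-value copy, so functors whose operator()
        // is not const (as in TEST3 and TEST4) can be called too.
        const T gradient = (func(x + h) - func(x - h)) / (T{2} * h);
        if (std::abs(gradient) < tolerance) {
            break;  // (approximately) stationary point reached
        }
        x -= step * gradient;  // move against the gradient
    }
    return x;
}

// Lets an overload set such as cos be passed directly: T func(T) is
// adjusted to T (*)(T), and T is deduced from the first two arguments.
template <typename T>
T gradient_descent(const T& init_value, const T& step, T func(T)) {
    return gradient_descent<T, T (*)(T)>(init_value, step, func);
}

}  // namespace q1
```

Taking Func by value (rather than const Func&) is what lets the non-const operator() in the test functors compile. As with any plain gradient descent, starting exactly at a stationary point (for example cos at 0.0, where the estimated gradient is numerically zero) returns the starting value immediately, even though that point is a maximum.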