c++ ode integrator

Separating ODE and ODE solver in base and derived classes


The question is a bit long, so I think it is better to consider an oversimplified version of it first:

There are two classes, A and B, where B inherits from A. A member function of B (add) needs to be called through a member function of A.

class A;
typedef int(A::*operation)(int c);
typedef void (A::*myfunction)(operation, int);

class A
{
public:
    int a;
    int b;

    void do_something(myfunction f, operation op)
    {
        (this->*f)(op, 1);
    }

    void dummy_name(operation op, int a)
    {
        int c = (this->*op)(a);
    }
};

class B : public A
{
public:
    int a, b;

    B(int a, int b): a(a), b(b) {}
    int add(int c)
    {
        return a+b+c;
    }

};

int main()
{
    B inst_B(2, 5);
    inst_B.do_something(&A::dummy_name, &B::add);
}

simple.cpp:45:41: error: cannot convert ‘int (B::*)(int)’ to ‘operation’ {aka ‘int (A::*)(int)’}
   45 |     inst_B.do_something(&A::dummy_name, &B::add);
      |                                         ^~~~~~~
      |                                         |
      |                                         int (B::*)(int)
simple.cpp:17:47: note:   initializing argument 2 of ‘void A::do_something(myfunction, operation)’
   17 |     void do_something(myfunction f, operation op)
      |                                     ~~~~~~~~~~^~

To write a simple ODE solver without copying the integrator into the class of every model that contains a system of ordinary differential equations, I have separated the solver and the equations into two classes, with the model inheriting from the ODE solver.

class HarmonicOscillator: public ODE_Solver

This is a simplified example that contains only a few parameters. To avoid passing many parameters around, and for the sake of abstraction, I preferred to define the ODE in a class.

I have also used two member-function pointer types, one for the derivative (the right-hand side of dy/dt = f(y, t)) and one for the integrator (here only an Euler integrator). This is what I came up with:

#include <iostream>
#include <assert.h>
#include <random>
#include <vector>
#include <string>

using std::string;
using std::vector;

class ODE_Solver;
class HarmonicOscillator;
typedef vector<double> dim1;
typedef dim1 (ODE_Solver::*derivative)(const dim1 &, dim1&, const double t);
typedef void (ODE_Solver::*Integrator)(derivative, dim1 &, dim1&, const double t);


class ODE_Solver
{
    public: 
    ODE_Solver()
    {}

    double t;
    double dt;
    dim1 state;
    dim1 dydt;

    void integrate(Integrator integrator, 
                   derivative ode_system, 
                   const int N,
                   const double ti, 
                   const double tf, 
                   const double dt)
    {
        dim1 dydt(N);
        const size_t num_steps = int((tf-ti) / dt);
        for (size_t step = 0; step < num_steps; ++step)
        {   
            double t = step * dt;
            (this->*integrator)(ode_system, state, dydt, t);
            // print state
        }
    }

    void eulerIntegrator(derivative RHS, dim1 &y, dim1 &dydt, const double t)
    {
        int n = y.size();
        (this->*RHS)(y, dydt, t);
        for (int i = 0; i < n; i++)
            y[i] += dydt[i] * dt;
    }
};

class HarmonicOscillator: public ODE_Solver
{

public:
    int N;
    double dt;
    double gamma;
    string method;
    dim1 state;

    // constructor
    HarmonicOscillator(int N,
                       double gamma,
                       dim1 state
                       ) : N {N}, gamma{gamma}, state{state}
    { }
    //---------------------------------------------------//
    dim1 dampedOscillator(const dim1 &x, dim1&dxdt, const double t)
    {
        dxdt[0] = x[1];
        dxdt[1] = -x[0] - gamma * x[1];

        return dxdt;
    }
};

//-------------------------------------------------------//

int main(int argc, char **argv)
{
    const int N = 2;
    const double gamma = 0.05;
    const double t_iinit = 0.0;
    const double t_final = 10.0;
    const double dt = 0.01;

    dim1 x0{0.0, 1.0};
    HarmonicOscillator ho(N, gamma, x0);
    ho.integrate(&ODE_Solver::eulerIntegrator,
                 &HarmonicOscillator::dampedOscillator, 
                 N, t_iinit, t_final, dt);

    return 0;
}

I get these errors:

example.cpp: In function ‘int main(int, char**)’:
example.cpp:93:18: error: cannot convert ‘dim1 (HarmonicOscillator::*)(const dim1&, dim1&, double)’ {aka ‘std::vector<double> (HarmonicOscillator::*)(const std::vector<double>&, std::vector<double>&, double)’} to ‘derivative’ {aka ‘std::vector<double> (ODE_Solver::*)(const std::vector<double>&, std::vector<double>&, double)’}
   93 |                  &HarmonicOscillator::dampedOscillator,
      |                  ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
      |                  |
      |                  dim1 (HarmonicOscillator::*)(const dim1&, dim1&, double) {aka std::vector<double> (HarmonicOscillator::*)(const std::vector<double>&, std::vector<double>&, double)}
example.cpp:29:31: note:   initializing argument 2 of ‘void ODE_Solver::integrate(Integrator, derivative, int, double, double, double)’
   29 |                    derivative ode_system,
      |                    ~~~~~~~~~~~^~~~~~~~~~

If I define the ODE in the same class as the ODE solver, it works (link to github). How can I make it work when the ODE is defined in the derived class?


Solution

  • There are several issues with your code. First, when you want to pass arbitrary functions as arguments to other functions, consider using std::function.

    Second, inheritance should be used to declare "is-a" relationships. Since a harmonic oscillator is not an ODE solver, don't use inheritance for this. There is no "has-a" relationship either, so composition is not appropriate. Instead, the solver acts on a given function, so the most appropriate thing to do is to pass the harmonic oscillator as a parameter to the solver function.

    An example of what the code might look like:

    #include <functional>  // std::function

    class HarmonicOscillator {
        ...
    public:
        ...
        double operator()(double t) {
            ...
            return /* value at time t */;
        }
    };
    
    double integrate(std::function<double(double)> func, double start, double end, double dt) {
        double sum = 0;
    
        for (double t = start; t < end; t += dt)
            sum += func(t) * dt;
    
        return sum;
    }
    

    And then you just call it like so:

    HarmonicOscillator ho(...);
    auto result = integrate(ho, t_iinit, t_final, dt);
    

    The above might not do exactly what you want, but that is the structure of the code I think you should aim for.
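
Applied to the ODE problem in the question, that structure might look like the following minimal sketch. It assumes an explicit Euler step and keeps the question's vector<double> state. The names State, Derivative and eulerStep, and the signature of the free function integrate, are illustrative choices rather than part of the original code.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using State = std::vector<double>;
// Right-hand side of dy/dt = f(y, t): writes the derivative into dydt.
using Derivative = std::function<void(const State&, State&, double)>;

// The model only knows its own equations, not how they are integrated.
class HarmonicOscillator {
public:
    explicit HarmonicOscillator(double gamma) : gamma_(gamma) {}

    void operator()(const State& x, State& dxdt, double /*t*/) const {
        dxdt[0] = x[1];
        dxdt[1] = -x[0] - gamma_ * x[1];
    }

private:
    double gamma_;
};

// One explicit Euler step: y <- y + dt * f(y, t).
void eulerStep(const Derivative& rhs, State& y, State& dydt, double t, double dt) {
    rhs(y, dydt, t);
    for (std::size_t i = 0; i < y.size(); ++i)
        y[i] += dydt[i] * dt;
}

// Free-function solver: advances a copy of y from ti to tf in fixed steps
// of size dt and returns the final state.
State integrate(const Derivative& rhs, State y, double ti, double tf, double dt) {
    assert(dt > 0.0 && tf >= ti);
    State dydt(y.size());
    const std::size_t num_steps = static_cast<std::size_t>((tf - ti) / dt);
    for (std::size_t step = 0; step < num_steps; ++step)
        eulerStep(rhs, y, dydt, ti + step * dt, dt);
    return y;
}
```

With this layout, a call such as `integrate(HarmonicOscillator(0.05), {0.0, 1.0}, 0.0, 10.0, 0.01)` returns the state after the last step, and a different model only has to provide another callable with the same signature.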

    If you want to handle functions that take and return arbitrary types, not just a double, you could make integrate() a template:

    template <typename Function, typename T>
    auto integrate(Function func, T start, T end, T dt) {
        decltype(func(start)) sum{};
    
        for (T t = start; t < end; t += dt)
            sum += func(t) * dt;
    
        return sum;
    }
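
As a sketch of what such a call could look like with a vector-valued function, the block below repeats the template (with a precondition check added) and pairs it with a small value type. Vec2 and velocity are hypothetical names introduced only for this illustration; a library type that provides the same operators would serve equally well.

```cpp
#include <cassert>

// Hypothetical two-component value type that provides exactly the
// operations the template needs: value-initialization, += and * scalar.
struct Vec2 {
    double x;
    double y;

    Vec2& operator+=(const Vec2& other) {
        x += other.x;
        y += other.y;
        return *this;
    }
};

// Scaling by a scalar, required for the expression `func(t) * dt`.
Vec2 operator*(const Vec2& v, double s) { return {v.x * s, v.y * s}; }

// Same left Riemann sum as above, with a check that the step is positive.
template <typename Function, typename T>
auto integrate(Function func, T start, T end, T dt) {
    assert(dt > T{});
    decltype(func(start)) sum{};  // zero-initialized for Vec2

    for (T t = start; t < end; t += dt)
        sum += func(t) * dt;

    return sum;
}

// Hypothetical integrand: a vector-valued function of time.
Vec2 velocity(double t) { return {1.0, 2.0 * t}; }
```

Here `integrate(velocity, 0.0, 1.0, 0.25)` accumulates both components in a single pass, because Vec2 supplies the operators the template relies on.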
    

    This works if you create proper types for the input and output values that support arithmetic operations; it won't work with your dim1. I recommend finding a library that implements mathematical vector types, such as Eigen.
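
As for the compiler error itself: a pointer to a member of a derived class does not convert implicitly to a pointer to a member of its base class. Only the opposite direction is implicit, since every member of A is also a member of B but not the other way round. The sketch below shows one way to make the simplified A/B example compile, assuming an explicit static_cast is acceptable and that the pointer is only ever invoked on objects that really are of type B (otherwise the behavior is undefined). The names apply and runExample are introduced here for illustration.

```cpp
#include <cassert>

class A;
typedef int (A::*operation)(int c);

class A {
public:
    // Calls `op` on this object. The caller must guarantee that the dynamic
    // type of *this actually contains the member `op` points to.
    int apply(operation op, int c) {
        assert(op != nullptr);
        return (this->*op)(c);
    }
};

class B : public A {
public:
    B(int a, int b) : a(a), b(b) {}
    int add(int c) { return a + b + c; }

private:
    int a, b;
};

int runExample() {
    B inst_B(2, 5);
    // int (B::*)(int) -> int (A::*)(int) requires an explicit cast.
    return inst_B.apply(static_cast<operation>(&B::add), 1);
}
```

This removes the error but keeps the inheritance-based design, so the caveats above about "is-a" relationships still apply.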