Tags: matlab, oop, function-handle

MATLAB: Pass class function handle to ode45()?


I have a class method that uses ode45 to solve some equations, and another, private method that implements the ODE function that ode45 needs to integrate. However, I can't figure out how to pass a handle to the class's ODE function to ode45. Here is the sample code:

classdef ODESolver < handle

    methods (Access = public)

        function obj = RunODE(obj, t, y0)
            [~, Z] = ode45(@ODEFunction, t, y0);
        end

    end

    methods (Access = private)

        function dy = ODEFunction(t,y)
            % Calculate dy here.
        end

    end

end

When I run this, I get an error saying:

Undefined function 'ODEFunction' for input arguments of type 'double'.

If I move ODEFunction outside of the class and put it in its own *.m file, the code runs fine. I've also tried using "@obj.ODEFunction" in the ode45 call, but then it says:

Too many input arguments.

What's the best way to keep ODEFunction inside my class and still be able to pass its handle to ode45?


Solution

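  • Your private ODEFunction is not a static method, so it must take obj as its first argument and be called through obj. Because ode45 calls its function with only (t, y), wrap the call in an anonymous function that captures obj: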

    classdef ODESolver < handle
    
        methods (Access = public)
    
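            function [T, Z] = RunODE(obj, t, y0)
                % ode45 calls its function as f(t, y); the anonymous
                % function captures obj and forwards both arguments.
                [T, Z] = ode45(@(tt, yy) obj.ODEFunction(tt, yy), t, y0);
            end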
    
        end
    
        methods (Access = private)
    
            function dy = ODEFunction(obj, t, y)
                % obj comes first, as for any non-static method.
                % Placeholder ODE: must return a column vector the size of y.
                dy = -0.1 * y;
            end
    
        end
    
    end
    

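    NB: You also forgot to declare obj as the first argument of the private ODEFunction. That is also why "@obj.ODEFunction" reported "Too many input arguments": the method was called with obj, t and y, but it declared only t and y. A version with a static method follows.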

    Edit

    If you intended ODEFunction to be a private static method, write it like this instead. A static method takes no obj argument and is called through the class name:

    classdef ODESolver < handle
    
        methods (Access = public)
    
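            function [T, Z] = RunODE(obj, t, y0)
                % Static methods are called through the class name.
                % For a static method, @ODESolver.ODEFunction would also work here.
                [T, Z] = ode45(@(tt, yy) ODESolver.ODEFunction(tt, yy), t, y0);
            end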
    
        end
    
        methods (Static, Access = private)
    
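            function dy = ODEFunction(t, y)
                % No obj argument: static methods are not bound to an instance.
                % Placeholder ODE: must return a column vector the size of y.
                dy = -0.1 * y;
            end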
    
        end
    
    end
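The key point in both versions is that ode45 always calls its function with exactly two arguments, t and y, so anything else the function needs has to be captured by an anonymous function. Below is a minimal standalone sketch of that binding with no class involved. The parameter k and the names rhs and f are hypothetical: k stands in for obj, and rhs stands in for a method that takes it as its first argument.

```matlab
% Hypothetical standalone demo of binding an extra leading argument.
k = 0.5;                       % hypothetical parameter, plays the role of obj
rhs = @(p, t, y) -p * y;       % like ODEFunction(obj, t, y): dy/dt = -p*y
f = @(t, y) rhs(k, t, y);      % captures k, exposes the (t, y) signature ode45expects
[t, y] = ode45(f, [0 2], 1);   % integrate from t = 0 to t = 2 with y(0) = 1
fprintf('y(2) = %.4f (exact: %.4f)\n', y(end), exp(-k * 2));
```

Note that k is copied into f when f is created. Because ODESolver derives from handle, the obj captured in the answer's wrapper is a reference, so ODEFunction would still see property changes made after the wrapper was created.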