c function function-pointers function-call

Returning a Function Pointer from a Function and Calling it with a Pointer. How does this Work Exactly?


I was reading some lecture notes on function pointers and came across the following code:

int (*Convert(const char code))(int, int) {
    if (code == '+') return &Sum; // Takes two ints, and adds
    if (code == '-') return &Difference; // Takes two ints, and subtracts
}

int main () {
    int (*ptr)(int,int);
    ptr = Convert('+');
    printf("%d\n", ptr(2,4));
}

I'm used to seeing a function that returns a function pointer called like this, which makes sense to me because all the arguments are laid out: the char and the two ints:

Convert('-')(5, 6);

But with the way it is written in the notes, I can't quite grasp what is going on. Can someone tell how exactly does this work? Does it have to do with assigning (*ptr)(int, int) the function's address or something?


Solution

  • Can someone tell how exactly does this work? Does it have to do with assigning (*ptr)(int, int) the function's address or something?

    Function Convert() returns a pointer to a function: a pointer to Sum() if its argument is '+', or a pointer to Difference() if its argument is '-'. For any other argument it reaches its closing brace without returning a value, and using the result of such a call is undefined behavior. The returned function pointer is stored in variable ptr, which is declared with the same type that Convert() returns. The pointed-to function can then be called by applying the function-call operator, (), to ptr, so ptr(2,4) calls Sum(2, 4). Convert('-')(5, 6) does the same thing in a single expression: it applies () directly to the returned pointer instead of storing it first.

    Perhaps it would be a bit clearer if rewritten in this equivalent way, with use of a typedef:

    typedef int (*op_function)(int, int);

    op_function Convert(const char code) {
        if (code == '+') return &Sum; // Takes two ints, and adds
        if (code == '-') return &Difference; // Takes two ints, and subtracts
    }

    int main () {
        op_function ptr;
        ptr = Convert('+');
        printf("%d\n", ptr(2,4));
    }