Tags: javascript, haskell, currying, lambda-calculus, partial-application

How to correctly curry a function in JavaScript?


I wrote a simple curry function in JavaScript which works correctly for most cases:

const curry = (f, ...a) => a.length < f.length
    ? (...b) => curry(f, ...a, ...b)
    : f(...a);

const add = curry((a, b, c) => a + b + c);

const add2 = add(2);

const add5 = add2(3);

console.log(add5(5));

However, it doesn't work for the following case:

// length :: [a] -> Number
const length = a => a.length;

// filter :: (a -> Bool) -> [a] -> [a]
const filter = curry((f, a) => a.filter(f));

// compose :: (b -> c) -> (a -> b) -> a -> c
const compose = curry((f, g, x) => f(g(x)));

// countWhere :: (a -> Bool) -> [a] -> Number
const countWhere = compose(compose(length), filter);

According to the following question, countWhere is defined as (length .) . filter:

What does (f .) . g mean in Haskell?

Hence I should be able to use countWhere as follows:

const odd = n => n % 2 === 1;

countWhere(odd, [1,2,3,4,5]);

However, instead of returning 3 (the length of the array [1,3,5]), it returns a function. What am I doing wrong?


Solution

  • The problem with your curry function (and with most curry functions that people write in JavaScript) is that it doesn't handle extra arguments correctly.

    What curry does

    Suppose f is a function and f.length is n. Let curry(f) be g. We call g with m arguments. What should happen?

    1. If m === 0 then just return g.
    2. If m < n then partially apply f to the m new arguments, and return a new curried function which accepts the remaining n - m arguments.
    3. Otherwise apply f to the m arguments and return the result.

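    This is what most curry functions do, and it is wrong. The first two cases are right, but the third case is not: when m > n the extra m - n arguments are silently ignored by f, and when the result is itself a function it is returned uncurried. Instead, it should be: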

    1. If m === 0 then just return g.
    2. If m < n then partially apply f to the m new arguments, and return a new curried function which accepts the remaining n - m arguments.
    3. If m === n then apply f to the m arguments. If the result is a function then curry the result. Finally, return the result.
    4. If m > n then apply f to the first n arguments. If the result is a function then curry the result. Finally, apply the result to the remaining m - n arguments and return the new result.

    The problem with most curry functions

    Consider the following code:

    const countWhere = compose(compose(length), filter);
    
    countWhere(odd, [1,2,3,4,5]);
    

    If we use an incorrect curry function, then this is equivalent to:

    compose(compose(length), filter, odd, [1,2,3,4,5]);
    

    However, compose only accepts three arguments. The last argument is dropped:

    const compose = curry((f, g, x) => f(g(x)));
    

    Hence, the above expression evaluates to:

    compose(length)(filter(odd));
    

    This further evaluates to:

    compose(length, filter(odd));
    

    The compose function still expects one more argument, which is why it returns a function instead of 3. To get the correct output you need to write:

    countWhere(odd)([1,2,3,4,5]);
    

    This is the reason why most curry functions are wrong.
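
The trace above can be reproduced with a small standalone sketch that uses the curry from the question unchanged (it assumes Node.js or any ES2015+ environment). It shows that passing both arguments at once leaves a function waiting for x, and that the same function does produce 3 once it finally receives the array.

```javascript
// The curry from the question: it calls f as soon as it has at least
// f.length arguments; any extra arguments are silently ignored by f.
const curry = (f, ...a) => a.length < f.length
  ? (...b) => curry(f, ...a, ...b)
  : f(...a);

const length = a => a.length;
const filter = curry((f, a) => a.filter(f));
const compose = curry((f, g, x) => f(g(x)));
const odd = n => n % 2 === 1;

const countWhere = compose(compose(length), filter);

// All four arguments reach compose at once, so [1,2,3,4,5] is ignored and
// we are left with compose(length, filter(odd)), still waiting for x.
const wrong = countWhere(odd, [1, 2, 3, 4, 5]);
console.log(typeof wrong);           // function
console.log(wrong([1, 2, 3, 4, 5])); // 3

// Supplying the arguments one call at a time avoids the ignored argument.
const right = countWhere(odd)([1, 2, 3, 4, 5]);
console.log(right); // 3
```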

    The solution using the correct curry function

    Consider the following code again:

    const countWhere = compose(compose(length), filter);
    
    countWhere(odd, [1,2,3,4,5]);
    

    If we use the correct curry function, then this is equivalent to:

    compose(compose(length), filter, odd)([1,2,3,4,5]);
    

    Which evaluates to:

    compose(length)(filter(odd))([1,2,3,4,5]);
    

    Which further evaluates to (skipping an intermediate step):

    compose(length, filter(odd), [1,2,3,4,5]);
    

    Which results in:

    length(filter(odd, [1,2,3,4,5]));
    

    Producing the correct result 3.
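
Each step of this trace can be checked mechanically. The sketch below copies the corrected curry from the implementation section of this answer so it runs on its own, then evaluates every expression in the trace; each one should come out as 3.

```javascript
// Corrected curry, copied from the answer's implementation.
const curry = (f, ...a) => {
  const n = f.length, m = a.length;
  if (n === 0) return m > n ? f(...a) : f;
  if (m === n) return autocurry(f(...a));
  if (m < n) return (...b) => curry(f, ...a, ...b);
  return curry(f(...a.slice(0, n)), ...a.slice(n));
};
const autocurry = (x) => typeof x === "function" ? curry(x) : x;

const length = a => a.length;
const filter = curry((f, a) => a.filter(f));
const compose = curry((f, g, x) => f(g(x)));
const odd = n => n % 2 === 1;
const xs = [1, 2, 3, 4, 5];

// One entry per line of the evaluation trace.
const steps = [
  compose(compose(length), filter)(odd, xs),
  compose(compose(length), filter, odd)(xs),
  compose(length)(filter(odd))(xs),
  compose(length, filter(odd), xs),
  length(filter(odd, xs)),
];
console.log(steps.join(", ")); // 3, 3, 3, 3, 3
```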

    The implementation of the correct curry function

    Implementing the correct curry function in ES6 is straightforward:

    const curry = (f, ...a) => {
        const n = f.length, m = a.length;
        if (n === 0) return m > n ? f(...a) : f;
        if (m === n) return autocurry(f(...a));
        if (m < n) return (...b) => curry(f, ...a, ...b);
        return curry(f(...a.slice(0, n)), ...a.slice(n));
    };
    
    const autocurry = (x) => typeof x === "function" ? curry(x) : x;
    

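    Note that if the length of the input function is 0 then it's assumed to be curried. This matters because the partially applied functions that curry returns use a rest parameter (...b), so their length is 0; treating them as already curried lets curry pass any arguments straight through to them.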
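
To see the four cases from the start of this answer in action, here is a usage sketch of the implementation above (repeated so it runs standalone). The names add3 and adder are illustrative only; adder is a two-argument function that returns another function, which is what exercises the m > n case.

```javascript
const curry = (f, ...a) => {
  const n = f.length, m = a.length;
  if (n === 0) return m > n ? f(...a) : f;
  if (m === n) return autocurry(f(...a));
  if (m < n) return (...b) => curry(f, ...a, ...b);
  return curry(f(...a.slice(0, n)), ...a.slice(n));
};
const autocurry = (x) => typeof x === "function" ? curry(x) : x;

// Illustrative functions (not from the answer).
const add3 = curry((a, b, c) => a + b + c);
const adder = curry((a, b) => c => a + b + c);

// m === 0: still a function waiting for all three arguments.
console.log(typeof add3()); // function

// m < n: partial application returns a rest-parameter wrapper whose
// length is 0, so curry treats it as already curried.
const add1 = add3(1);
console.log(add1.length); // 0
console.log(add1(2, 3));  // 6

// m === n: f is applied directly.
console.log(add3(1, 2, 3)); // 6

// m > n: adder is applied to (1, 2), and the function it returns is
// curried and applied to the leftover argument 3.
console.log(adder(1, 2, 3)); // 6
console.log(adder(1)(2)(3)); // 6
```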

    Implications of using the correct curry function

    Using the correct curry function allows you to directly translate Haskell code into JavaScript. For example:

    const id = curry(a => a);
    
    const flip = curry((f, x, y) => f(y, x));
    

    The id function is useful because it allows you to partially apply a non-curried function easily:

    const add = (a, b) => a + b;
    
    const add2 = id(add, 2);
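
A quick standalone check of the id trick, using the corrected curry from above. id has length 1, so id(add, 2) has one argument too many: curry applies id to add and forwards 2 to the result, which curries and partially applies add.

```javascript
const curry = (f, ...a) => {
  const n = f.length, m = a.length;
  if (n === 0) return m > n ? f(...a) : f;
  if (m === n) return autocurry(f(...a));
  if (m < n) return (...b) => curry(f, ...a, ...b);
  return curry(f(...a.slice(0, n)), ...a.slice(n));
};
const autocurry = (x) => typeof x === "function" ? curry(x) : x;

const id = curry(a => a);
const add = (a, b) => a + b; // plain, uncurried

// id(add) is just add; the extra argument 2 is forwarded to it.
const add2 = id(add, 2);
console.log(add2(3));   // 5
console.log(add(2, 3)); // 5 (the original add is unchanged)
```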
    

    The flip function is useful because it allows you to easily create right sections in JavaScript:

    const sub = (a, b) => a - b;
    
    const sub2 = flip(sub, 2); // equivalent to (x - 2)
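
A standalone sketch of the flip example with the corrected curry. flip(sub, 2) supplies f and x, so the result waits for y and then computes sub(y, 2).

```javascript
const curry = (f, ...a) => {
  const n = f.length, m = a.length;
  if (n === 0) return m > n ? f(...a) : f;
  if (m === n) return autocurry(f(...a));
  if (m < n) return (...b) => curry(f, ...a, ...b);
  return curry(f(...a.slice(0, n)), ...a.slice(n));
};
const autocurry = (x) => typeof x === "function" ? curry(x) : x;

const flip = curry((f, x, y) => f(y, x));
const sub = (a, b) => a - b;

// sub2(y) computes sub(y, 2), i.e. y - 2.
const sub2 = flip(sub, 2);
console.log(sub2(10)); // 8
console.log(sub2(2));  // 0

// Wrap curried functions before handing them to Array.prototype.map,
// which passes (element, index, array); this curry forwards the extras.
console.log([10, 12].map(x => sub2(x)).join(", ")); // 8, 10
```

One thing to keep in mind with this curry: because it forwards extra arguments instead of ignoring them, passing sub2 directly as a map callback would forward the index and array too and throw, which is why the sketch wraps it as x => sub2(x).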
    

    It also means that you don't need hacks like this extended compose function:

    What's a Good Name for this extended `compose` function?

    You can simply write:

    const project = compose(map, pick);
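
The answer does not define map or pick, so this sketch assumes simple hypothetical versions of both (map over an array, and pick copying the listed keys of an object) to show that compose(map, pick) accepts both of its arguments in a single call, or one at a time.

```javascript
// Corrected curry from the answer, repeated so the snippet runs standalone.
const curry = (f, ...a) => {
  const n = f.length, m = a.length;
  if (n === 0) return m > n ? f(...a) : f;
  if (m === n) return autocurry(f(...a));
  if (m < n) return (...b) => curry(f, ...a, ...b);
  return curry(f(...a.slice(0, n)), ...a.slice(n));
};
const autocurry = (x) => typeof x === "function" ? curry(x) : x;

const compose = curry((f, g, x) => f(g(x)));

// Hypothetical helpers (not defined in the answer).
// map wraps f in x => f(x) so Array.prototype.map's extra (index, array)
// arguments are not forwarded into a curried f.
const map = curry((f, xs) => xs.map(x => f(x)));
// pick copies only the listed keys that exist on obj.
const pick = curry((keys, obj) =>
  Object.fromEntries(keys.filter(k => k in obj).map(k => [k, obj[k]])));

const project = compose(map, pick);

const people = [
  { name: "Ann", age: 31, city: "Oslo" },
  { name: "Bob", age: 25, city: "Rome" },
];
const projected = project(["name", "age"], people);
console.log(JSON.stringify(projected));
// [{"name":"Ann","age":31},{"name":"Bob","age":25}]
```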
    

    As mentioned in the question, if you want to compose length and filter then you can use the (f .) . g pattern:

    What does (f .) . g mean in Haskell?

    Another solution is to create higher-order compose functions:

    const compose2 = compose(compose, compose);
    
    const countWhere = compose2(length, filter);
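
A standalone check of compose2 under the corrected curry. compose(compose, compose) plays the role of Haskell's (.) . (.), and the countWhere it builds behaves like compose(compose(length), filter) for both calling styles.

```javascript
const curry = (f, ...a) => {
  const n = f.length, m = a.length;
  if (n === 0) return m > n ? f(...a) : f;
  if (m === n) return autocurry(f(...a));
  if (m < n) return (...b) => curry(f, ...a, ...b);
  return curry(f(...a.slice(0, n)), ...a.slice(n));
};
const autocurry = (x) => typeof x === "function" ? curry(x) : x;

const length = a => a.length;
const filter = curry((f, a) => a.filter(f));
const compose = curry((f, g, x) => f(g(x)));
const odd = n => n % 2 === 1;

// compose2 composes a one-argument function after a two-argument one.
const compose2 = compose(compose, compose);
const countWhere = compose2(length, filter);

console.log(countWhere(odd, [1, 2, 3, 4, 5]));  // 3
console.log(countWhere(odd)([1, 2, 3, 4, 5])); // 3
```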
    

    This is all possible because of the correct implementation of the curry function.

    Extra food for thought

    I usually use the following chain function when I want to compose a chain of functions:

    // Apply the functions in a from right to left, like Haskell's (.)
    const chain = curry((a, x) => {
        let i = a.length;
        while (i > 0) x = a[--i](x);
        return x;
    });
    

    This allows you to write code like:

    const inc = id(add, 1);
    
    const foo = chain([map(inc), filter(odd), take(5)]);
    
    foo([1,2,3,4,5,6,7,8,9,10]); // [2,4,6]
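
To try the chain example end to end, the sketch below builds chain with curry so it accepts one or two arguments, and assumes hypothetical curried map and take helpers, since the answer does not define them. inc is built with the id trick from earlier, because add is uncurried.

```javascript
const curry = (f, ...a) => {
  const n = f.length, m = a.length;
  if (n === 0) return m > n ? f(...a) : f;
  if (m === n) return autocurry(f(...a));
  if (m < n) return (...b) => curry(f, ...a, ...b);
  return curry(f(...a.slice(0, n)), ...a.slice(n));
};
const autocurry = (x) => typeof x === "function" ? curry(x) : x;

// chain applies the functions in fs from right to left.
const chain = curry((fs, x) => {
  let i = fs.length;
  while (i > 0) x = fs[--i](x);
  return x;
});

const id = curry(a => a);
const add = (a, b) => a + b;
const inc = id(add, 1);
const odd = n => n % 2 === 1;

// Hypothetical list helpers (map and take are not defined in the answer).
const map = curry((f, xs) => xs.map(x => f(x)));
const filter = curry((f, a) => a.filter(f));
const take = curry((n, xs) => xs.slice(0, n));

const xs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
const foo = chain([map(inc), filter(odd), take(5)]);
console.log(JSON.stringify(foo(xs))); // [2,4,6]
console.log(JSON.stringify(chain([map(inc), filter(odd), take(5)], xs))); // [2,4,6]
```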
    

    Which is equivalent to the following Haskell code:

    let foo = map (+1) . filter odd . take 5
    
    foo [1,2,3,4,5,6,7,8,9,10]
    

    It also allows you to write code like:

    chain([map(inc), filter(odd), take(5)], [1,2,3,4,5,6,7,8,9,10]); // [2,4,6]
    

    Which is equivalent to the following Haskell code:

    map (+1) . filter odd . take 5 $ [1,2,3,4,5,6,7,8,9,10]
    

    Hope that helps.