How to use the Trampoline type as the base monad of a transformer more efficiently?


I have an array transformer type that exhibits interleaved effect layers to ensure a lawful effect implementation. You can read its structure directly from the type's of operation:

const arrOfT = of => x => of([of(x)]);

The type implements an effectful fold as its basic operation. I use a left fold, because the underlying array type is inherently strict:

const arrFoldT = chain => f => init => mmx =>
  chain(mmx) (mx => {
    const go = (acc, i) =>
      i === mx.length
        ? acc
        : chain(mx[i]) (x =>
            go(f(acc) (x), i + 1))
//          ^^^^^^^^^^^^^^^^^^^^^ non-tail call position
    return go(init, 0);
  });
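To make the problem concrete, here is a sketch that instantiates arrFoldT with a plain Identity chain. The identityChain and sumWith helpers and the test data are illustrative assumptions, not part of the original code. Every element adds nested go/chain frames to the native call stack, so a large enough array exhausts it:

```javascript
// Same fold as above, copied so this snippet runs on its own.
const arrFoldT = chain => f => init => mmx =>
  chain(mmx) (mx => {
    const go = (acc, i) =>
      i === mx.length
        ? acc
        : chain(mx[i]) (x =>
            go(f(acc) (x), i + 1));

    return go(init, 0);
  });

// Hypothetical Identity monad: a value is its own "effect", so chain is just application.
const identityChain = mx => fm => fm(mx);

const sumWith = xs =>
  arrFoldT(identityChain) (acc => x => acc + x) (0) (xs);

// Small inputs work fine.
const small = sumWith(Array(10).fill(1));

// Large inputs overflow: JavaScript engines do not eliminate these calls.
let overflowed = false;
try {
  sumWith(Array(1e5).fill(1));
} catch (e) {
  overflowed = e instanceof RangeError;
}

console.log(small, overflowed); // 10 true
```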

As you can see, the implementation is not stack safe. However, stack safety is just another computational effect that can be encoded through a monad. I implemented such a monad, the Trampoline type:

const monadRec = o => {
  while (o.tag === "Chain")
    o = o.f(o.x);

  return o.tag === "Of"
    ? o.x
    : _throw(new TypeError("unknown case"));
};

const recChain = mx => fm =>
  mx.tag === "Chain" ? Chain(mx.x) (x => recChain(mx.f(x)) (fm))
    : mx.tag === "Of" ? fm(mx.x)
    : _throw(new TypeError("unknown case"));

const Chain = x => f =>
  ({tag: "Chain", f, x});

const Of = x =>
  ({tag: "Of", x});
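As a quick sanity check of the trampoline on its own, a deep recursion expressed through Chain runs in constant stack space, because monadRec unwinds each Chain in its while loop. The countdown function is a hypothetical example, and _throw gets a one-line definition so the block runs on its own:

```javascript
// Helper so errors can be thrown from expression position.
const _throw = e => { throw e; };

// Trampoline constructors and runner, as defined above.
const Chain = x => f => ({tag: "Chain", f, x});
const Of = x => ({tag: "Of", x});

const monadRec = o => {
  while (o.tag === "Chain")
    o = o.f(o.x);

  return o.tag === "Of"
    ? o.x
    : _throw(new TypeError("unknown case"));
};

// Hypothetical example: each step returns a Chain instead of recursing directly,
// so the loop in monadRec, not the call stack, drives the iteration.
const countdown = n =>
  n === 0
    ? Of("done")
    : Chain(n - 1) (countdown);

const result = monadRec(countdown(1e6));
console.log(result); // done
```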

While the implementations are straightforward, the application is not. I am pretty sure I am applying it all wrong:

const mmx = Of(
  Array(1e5)
    .fill(Chain(1) (x => Of(x))));
//                  ^^^^^^^^^^^^ redundant continuation

const main = arrFoldT(recChain)
  (acc => y => recMap(x => x + y) (acc))
    (Of(0))
      (mmx);

monadRec(main); // 100000

I need to use Chain when creating the large effectful array, because Of signals the control flow to break out of the trampoline. With Chain, on the other hand, I have to specify a redundant continuation.
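The following sketch reproduces both cases side by side. recChain applies the continuation immediately when it meets an Of value, so with Of(1) elements the fold's go recursion never hands a Chain back to monadRec and runs on the native call stack instead. The sumT wrapper and the overflow check are illustrative additions, and recMap is written via recChain here so that it also handles Chain values:

```javascript
const _throw = e => { throw e; };

const Chain = x => f => ({tag: "Chain", f, x});
const Of = x => ({tag: "Of", x});

const monadRec = o => {
  while (o.tag === "Chain")
    o = o.f(o.x);

  return o.tag === "Of"
    ? o.x
    : _throw(new TypeError("unknown case"));
};

const recChain = mx => fm =>
  mx.tag === "Chain" ? Chain(mx.x) (x => recChain(mx.f(x)) (fm))
    : mx.tag === "Of" ? fm(mx.x)
    : _throw(new TypeError("unknown case"));

// map expressed through recChain
const recMap = f => tx => recChain(tx) (x => Of(f(x)));

const arrFoldT = chain => f => init => mmx =>
  chain(mmx) (mx => {
    const go = (acc, i) =>
      i === mx.length
        ? acc
        : chain(mx[i]) (x =>
            go(f(acc) (x), i + 1));

    return go(init, 0);
  });

const sumT = mmx =>
  arrFoldT(recChain) (acc => y => recMap(x => x + y) (acc)) (Of(0)) (mmx);

// Chain elements: recChain defers the continuation, so monadRec drives the loop.
const viaChain = monadRec(sumT(Of(Array(1e5).fill(Chain(1) (x => Of(x))))));

// Of elements: recChain calls the continuation right away, so go recurses natively.
let viaOfOverflowed = false;
try {
  monadRec(sumT(Of(Array(1e5).fill(Of(1)))));
} catch (e) {
  viaOfOverflowed = e instanceof RangeError;
}

console.log(viaChain, viaOfOverflowed); // 100000 true
```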

My first idea was to flip Chain's arguments and rely on partial application, but this doesn't work with the current implementation.

Is there a way to use the type more efficiently?

Here is a working example:

// ARRAYT

const arrFoldT = chain => f => init => mmx =>
  chain(mmx) (mx => {
    const go = (acc, i) =>
      i === mx.length
        ? acc
        : chain(mx[i]) (x =>
            go(f(acc) (x), i + 1))

    return go(init, 0);
  });

// TRAMPOLINE

// helper so errors can be thrown from expression position
const _throw = e => { throw e; };

const monadRec = o => {
  while (o.tag === "Chain")
    o = o.f(o.x);

  return o.tag === "Of"
    ? o.x
    : _throw(new TypeError("unknown case"));
};

const Chain = x => f =>
  ({tag: "Chain", f, x});

const Of = x =>
  ({tag: "Of", x});

// Functor

// map via recChain, so it is correct for Chain values as well as Of values
const recMap = f => tx =>
  recChain(tx) (x => Of(f(x)));

// Monad

const recChain = mx => fm =>
  mx.tag === "Chain" ? Chain(mx.x) (x => recChain(mx.f(x)) (fm))
    : mx.tag === "Of" ? fm(mx.x)
    : _throw(new TypeError("unknown case"));

const recOf = Of;

// MAIN

const mmx = Of(
  Array(1e5)
    .fill(Chain(1) (x => Of(x))));

const main = arrFoldT(recChain)
  (acc => y => recMap(x => x + y) (acc))
    (Of(0))
      (mmx);

console.log(
  monadRec(main)); // 100000


Solution

  • First, the definition of your array monad transformer is wrong.

    ArrayT m a = m (Array (m a))
    

    The above type definition does not correctly interleave the underlying monad.

    Here is an example value of the above data type.

    of([of(1), of(2), of(3)])
    

    There are several problems with this data type.

    1. There is no effect for the end of the array.
    2. The effects are not ordered. Hence, they can be executed in any order.
    3. The underlying monad wraps the individual elements and then wraps the entire array again, instead of wrapping each step of the array exactly once.

    Here is an example value of the correct array monad transformer type.

    of([1, of([2, of([3, of([])])])])
    

    Note the following:

    1. There is an effect for the end of the array.
    2. The effects are ordered. This is because the data type is defined recursively.
    3. The underlying monad wraps the individual steps of the array. It doesn't wrap the entire array again.
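    To see points 1 and 2 in action, here is a small sketch that pairs the recursive encoding with a hypothetical lazy IO monad, where a computation is just a zero-argument function. IO and countTo are illustrative names, not from the answer, and this IO is not stack safe, so it is only meant for small lists. Each step, including the end of the list, logs when it runs, and nothing runs until the final computation is invoked:

```javascript
// Hypothetical lazy "IO" monad: a computation is a zero-argument function.
const IO = {
  pure: x => () => x,
  bind: m => f => () => f(m())(),
};

// foldr over the recursive encoding ListT m a = m (null | { head, tail }).
const foldr = M => f => a => m => M.bind(m)(step =>
  step ? f(step.head)(foldr(M)(f)(a)(step.tail)) : a);

// Build a ListT IO Number whose every step, including the end, logs when run.
const countTo = (n, log) => {
  const go = i => () => {
    if (i === n) {
      log.push("end");          // the effect for the end of the list
      return null;
    }
    log.push(`step ${i}`);      // the effect for producing element i
    return { head: i, tail: go(i + 1) };
  };
  return go(0);
};

const log = [];
const sum = foldr(IO)(x => mb => IO.bind(mb)(y => IO.pure(x + y)))(IO.pure(0))(countTo(3, log));

const before = log.length;      // nothing has run yet
const total = sum();            // runs the effects in list order
// before === 0, total === 3, log is ["step 0", "step 1", "step 2", "end"]
console.log(before, total, JSON.stringify(log));
```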

    Now, I understand why you want to define ArrayT m a = m (Array (m a)). If m = Identity then you get back an actual Array a, which supports random access of elements.

    of([of(1), of(2), of(3)]) === [1, 2, 3]
    

    On the other hand, the recursive array monad transformer type returns a linked list when m = Identity.

    of([1, of([2, of([3, of([])])])]) === [1, [2, [3, []]]]
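    Concretely, with m = Identity the value is a nested linked list, and getting a flat array back means walking it one element at a time, which is why random access is lost. A sketch, with toArray as a hypothetical helper:

```javascript
// Identity monad: of is the identity function.
const of = x => x;

// The recursive encoding collapses to a nested linked list when m = Identity.
const list = of([1, of([2, of([3, of([])])])]);

// Hypothetical helper: walk the list to recover a flat array in linear time.
const toArray = node => {
  const out = [];
  while (node.length !== 0) {   // [] marks the end of the list
    out.push(node[0]);          // head
    node = node[1];             // tail
  }
  return out;
};

console.log(JSON.stringify(list));          // [1,[2,[3,[]]]]
console.log(JSON.stringify(toArray(list))); // [1,2,3]
```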
    

    However, there's no way to create a lawful array monad transformer type which also returns an actual array when the underlying monad is Identity. This is because monad transformers are inherently algebraic data structures, and arrays are not algebraic.

    The closest you can get is to define ArrayT m a = Array (m a). However, this would only satisfy the monad laws when the underlying monad is commutative.

    Just remember two rules when defining a monad transformer data type:

    1. The underlying monad must wrap at most one value at a time.
    2. The underlying monad must be nested, to correctly order and interleave effects.

    Coming back to your question, the Trampoline monad is just the Free monad over thunks. We can define it as follows.

    // pure : a -> Free a
    const pure = value => ({ constructor: pure, value });
    
    // bind : Free a -> (a -> Free b) -> Free b
    const bind = monad => arrow => ({ constructor: bind, monad, arrow });
    
    // thunk : (() -> Free a) -> Free a
    // (eval is not a legal parameter name in strict-mode code, e.g. ES modules)
    const thunk = delay => ({ constructor: thunk, eval: delay });
    
    // MonadFree : Monad Free
    const MonadFree = { pure, bind };
    
    // evaluate : Free a -> a
    const evaluate = expression => {
        let expr = expression;
        let stack = null;
    
        while (true) {
            switch (expr.constructor) {
                case pure:
                    if (stack === null) return expr.value;
                    expr = stack.arrow(expr.value);
                    stack = stack.stack;
                    break;
                case bind:
                    stack = { arrow: expr.arrow, stack };
                    expr = expr.monad;
                    break;
                case thunk:
                    expr = expr.eval();
            }
        }
    };
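    Because evaluate keeps pending continuations in an explicit linked-list stack on the heap rather than on the JavaScript call stack, deeply left-nested binds and long thunk chains both evaluate without overflowing. A small sketch; inc and countUp are hypothetical examples, and thunk's parameter is named delay because eval is not a legal parameter name in strict mode:

```javascript
const pure = value => ({ constructor: pure, value });
const bind = monad => arrow => ({ constructor: bind, monad, arrow });
const thunk = delay => ({ constructor: thunk, eval: delay });

const evaluate = expression => {
  let expr = expression;
  let stack = null;             // linked list of pending continuations

  while (true) {
    switch (expr.constructor) {
      case pure:
        if (stack === null) return expr.value;
        expr = stack.arrow(expr.value);
        stack = stack.stack;
        break;
      case bind:
        stack = { arrow: expr.arrow, stack };
        expr = expr.monad;
        break;
      case thunk:
        expr = expr.eval();
    }
  }
};

// Left-nested binds, 1e5 levels deep: bind(bind(...(pure(0)))(inc))(inc)
const inc = x => pure(x + 1);
let leftNested = pure(0);
for (let i = 0; i < 1e5; i++) leftNested = bind(leftNested)(inc);

// Recursion delayed by thunk: each call returns a description immediately.
const countUp = (n, acc) =>
  n === 0 ? pure(acc) : thunk(() => countUp(n - 1, acc + 1));

const viaBind = evaluate(leftNested);
const viaThunk = evaluate(countUp(1e5, 0));
console.log(viaBind, viaThunk); // 100000 100000
```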
    

    I'll also copy my implementation of the array monad transformer from my previous answer.

    // Step m a = null | { head : a, tail : ListT m a }
    // ListT m a = m (Step m a)
    
    // nil : Monad m -> ListT m a
    const nil = M => M.pure(null);
    
    // cons : Monad m -> a -> ListT m a -> ListT m a
    const cons = M => head => tail => M.pure({ head, tail });
    
    // foldr : Monad m -> (a -> m b -> m b) -> m b -> ListT m a -> m b
    const foldr = M => f => a => m => M.bind(m)(step =>
        step ? f(step.head)(foldr(M)(f)(a)(step.tail)) : a);
    

    Thus, when the underlying monad is Free, the operations are stack safe.

    // replicate : Number -> a -> ListT Free a
    const replicate = n => x => n ?
        cons(MonadFree)(x)(thunk(() => replicate(n - 1)(x))) :
        nil(MonadFree);
    
    // map : (a -> b) -> Free a -> Free b
    const map = f => m => bind(m)(x => pure(f(x)));
    
    // add : Number -> Free Number -> Free Number
    const add = x => map(y => x + y);
    
    // result : Free Number
    const result = foldr(MonadFree)(add)(pure(0))(replicate(1000000)(1));
    
    console.log(evaluate(result)); // 1000000
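    Applied to the original use case, a large array of plain numbers can be turned into a ListT Free without any redundant continuations by delaying each tail with thunk. fromArray is a hypothetical helper, not part of the answer; the remaining definitions follow the answer, with thunk's parameter named delay for strict-mode compatibility:

```javascript
const pure = value => ({ constructor: pure, value });
const bind = monad => arrow => ({ constructor: bind, monad, arrow });
const thunk = delay => ({ constructor: thunk, eval: delay });
const MonadFree = { pure, bind };

const evaluate = expression => {
  let expr = expression;
  let stack = null;
  while (true) {
    switch (expr.constructor) {
      case pure:
        if (stack === null) return expr.value;
        expr = stack.arrow(expr.value);
        stack = stack.stack;
        break;
      case bind:
        stack = { arrow: expr.arrow, stack };
        expr = expr.monad;
        break;
      case thunk:
        expr = expr.eval();
    }
  }
};

const nil = M => M.pure(null);
const cons = M => head => tail => M.pure({ head, tail });
const foldr = M => f => a => m => M.bind(m)(step =>
  step ? f(step.head)(foldr(M)(f)(a)(step.tail)) : a);

const map = f => m => bind(m)(x => pure(f(x)));
const add = x => map(y => x + y);

// Hypothetical helper: plain array -> ListT Free, one element per step.
// Each tail is wrapped in thunk, so the list is built lazily while evaluate runs.
const fromArray = xs => {
  const go = i =>
    i === xs.length
      ? nil(MonadFree)
      : cons(MonadFree)(xs[i])(thunk(() => go(i + 1)));
  return go(0);
};

const total = evaluate(foldr(MonadFree)(add)(pure(0))(fromArray(Array(1e5).fill(1))));
console.log(total); // 100000
```

    The array elements stay plain values here; the only effect layer is the one Free step per list cell, which is what lets evaluate drive the whole fold from its loop.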
    

    Putting it all together.

    // pure : a -> Free a
    const pure = value => ({ constructor: pure, value });
    
    // bind : Free a -> (a -> Free b) -> Free b
    const bind = monad => arrow => ({ constructor: bind, monad, arrow });
    
    // thunk : (() -> Free a) -> Free a
    // (eval is not a legal parameter name in strict-mode code, e.g. ES modules)
    const thunk = delay => ({ constructor: thunk, eval: delay });
    
    // MonadFree : Monad Free
    const MonadFree = { pure, bind };
    
    // evaluate : Free a -> a
    const evaluate = expression => {
        let expr = expression;
        let stack = null;
    
        while (true) {
            switch (expr.constructor) {
                case pure:
                    if (stack === null) return expr.value;
                    expr = stack.arrow(expr.value);
                    stack = stack.stack;
                    break;
                case bind:
                    stack = { arrow: expr.arrow, stack };
                    expr = expr.monad;
                    break;
                case thunk:
                    expr = expr.eval();
            }
        }
    };
    
    // Step m a = null | { head : a, tail : ListT m a }
    // ListT m a = m (Step m a)
    
    // nil : Monad m -> ListT m a
    const nil = M => M.pure(null);
    
    // cons : Monad m -> a -> ListT m a -> ListT m a
    const cons = M => head => tail => M.pure({ head, tail });
    
    // foldr : Monad m -> (a -> m b -> m b) -> m b -> ListT m a -> m b
    const foldr = M => f => a => m => M.bind(m)(step =>
        step ? f(step.head)(foldr(M)(f)(a)(step.tail)) : a);
    
    // replicate : Number -> a -> ListT Free a
    const replicate = n => x => n ?
        cons(MonadFree)(x)(thunk(() => replicate(n - 1)(x))) :
        nil(MonadFree);
    
    // map : (a -> b) -> Free a -> Free b
    const map = f => m => bind(m)(x => pure(f(x)));
    
    // add : Number -> Free Number -> Free Number
    const add = x => map(y => x + y);
    
    // result : Free Number
    const result = foldr(MonadFree)(add)(pure(0))(replicate(1000000)(1));
    
    console.log(evaluate(result)); // 1000000