Is the Yoneda Lemma only useful from a theoretical point of view?


For instance, loop fusion can be obtained with Yoneda:

{-# LANGUAGE RankNTypes #-}  -- needed for the forall inside the newtype

newtype Yoneda f a =
    Yoneda (forall b. (a -> b) -> f b)

liftYo :: (Functor f) => f a -> Yoneda f a
liftYo x = Yoneda $ \f -> fmap f x

lowerYo :: (Functor f) => Yoneda f a -> f a
lowerYo (Yoneda y) = y id

instance Functor (Yoneda f) where
    fmap f (Yoneda y) = Yoneda $ \g -> y (g . f)

loopFusion = lowerYo . fmap f . fmap g . liftYo

But I could have just written loopFusion = fmap (f . g). Why would I use Yoneda at all? Are there other use cases?


Solution

  • Well, in this case you could have done fusion by hand, because the two fmaps are "visible" in the source code, but the point is that Yoneda does the transformation at runtime. It's a dynamic thing, most useful when you don't know how many times you will need to fmap over a structure. E.g. consider lambda terms:

    data Term v = Var v | App (Term v) (Term v) | Lam (Term (Maybe v))
    

    The Maybe under Lam represents the variable bound by the abstraction; in the body of a Lam, the variable Nothing refers to the bound variable, and all variables Just v represent the ones bound in the environment. (>>=) :: Term v -> (v -> Term v') -> Term v' represents substitution—each variable can be replaced with a Term. However, when replacing a variable inside a Lam, all the variables in the produced Term need to be wrapped in Just. E.g.

    (Lam $ Lam $ Var $ Just $ Just ())
      >>= \() -> App (Var "f") (Var "x")
    =
    Lam $ Lam $ App (Var $ Just $ Just "f") (Var $ Just $ Just "x")
    

    The naive implementation of (>>=) goes like this:

    (>>=) :: Term v -> (v -> Term v') -> Term v'
    Var x >>= f = f x
    App l r >>= f = App (l >>= f) (r >>= f)
    Lam b >>= f = Lam (b >>= maybe (Var Nothing) (fmap Just . f))
    

    But, written like this, every Lam that (>>=) goes under adds an fmap Just to f. If I had a Var v buried under 1000 Lams, then I would end up calling fmap Just and iterating over the new f v term 1000 times! I can't just pull your trick and fuse multiple fmaps into one, by hand, because there's only one fmap in the source code being called multiple times.
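The naive definition relies on a Functor instance for Term that the answer does not show. Here is a minimal sketch of what it could look like, under a couple of assumptions: the instance is written by hand, and the substitution is given the hypothetical name substNaive so that it does not clash with the Prelude's (>>=).

```haskell
-- Lambda terms as in the answer: Nothing is the variable bound by the
-- nearest Lam, Just v is a variable from the enclosing scope.
data Term v = Var v | App (Term v) (Term v) | Lam (Term (Maybe v))
  deriving (Show, Eq)

-- Renaming variables walks the whole term. Under a Lam the bound
-- variable (Nothing) is untouched and the remaining ones are renamed.
instance Functor Term where
  fmap g (Var x)   = Var (g x)
  fmap g (App l r) = App (fmap g l) (fmap g r)
  fmap g (Lam b)   = Lam (fmap (fmap g) b)

-- Hypothetical name for the naive (>>=) from the text.
-- The type signature is required because the recursion is polymorphic.
substNaive :: Term v -> (v -> Term v') -> Term v'
substNaive (Var x)   f = f x
substNaive (App l r) f = App (substNaive l f) (substNaive r f)
-- Every Lam crossed wraps f in one more `fmap Just`, and each of those
-- is a full traversal of the term that f produces.
substNaive (Lam b)   f = Lam (substNaive b (maybe (Var Nothing) (fmap Just . f)))

-- The worked example from the text, with explicit parentheses.
example :: Term String
example =
  substNaive (Lam (Lam (Var (Just (Just ())))))
             (\() -> App (Var "f") (Var "x"))
```

Evaluating example in GHCi should give the right-hand side of the worked example above, with both "f" and "x" wrapped in two Justs.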

    Yoneda can ease the pain:

    -- liftYoneda and lowerYoneda play the role of the question's liftYo and lowerYo.
    bindTerm :: Term v -> (v -> Yoneda Term v') -> Term v'
    bindTerm (Var x) f = lowerYoneda (f x)
    bindTerm (App l r) f = App (bindTerm l f) (bindTerm r f)
    bindTerm (Lam b) f =
      Lam (bindTerm b (maybe (liftYoneda $ Var Nothing) (fmap Just . f)))
    
    (>>=) :: Term v -> (v -> Term v') -> Term v'
    t >>= f = bindTerm t (liftYoneda . f)
    

    Now, the fmap Just is free; it's just a wonky function composition. The actual iteration over the produced Term is in the lowerYoneda, which is only called once for each Var. To reiterate: the source code nowhere contains anything of the form fmap f (fmap g x). Such forms only arise at runtime, dynamically, depending on the argument to (>>=). Yoneda can rewrite that, at runtime, to fmap (f . g) x, even though you can't rewrite it like that in the source code. Further, you can add Yoneda to existing code with minimal changes to it. (There is, however, a drawback: lowerYoneda is always called exactly once for each Var, which means e.g. Var v >>= f = fmap id (f v) where it was just f v, before.)
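For anyone who wants to try this out, the pieces can be assembled into one self-contained sketch. It makes a few assumptions: Yoneda, liftYoneda and lowerYoneda are defined locally (following the question's definitions) so that no extra package is needed, the Functor instance for Term is hand-written, and subst is a hypothetical name used instead of (>>=) to avoid clashing with the Prelude.

```haskell
{-# LANGUAGE RankNTypes #-}

data Term v = Var v | App (Term v) (Term v) | Lam (Term (Maybe v))
  deriving (Show, Eq)

instance Functor Term where
  fmap g (Var x)   = Var (g x)
  fmap g (App l r) = App (fmap g l) (fmap g r)
  fmap g (Lam b)   = Lam (fmap (fmap g) b)

-- Local copy of the question's Yoneda encoding.
newtype Yoneda f a = Yoneda (forall b. (a -> b) -> f b)

-- fmap on Yoneda only composes functions; no traversal happens here.
instance Functor (Yoneda f) where
  fmap f (Yoneda y) = Yoneda (\g -> y (g . f))

liftYoneda :: Functor f => f a -> Yoneda f a
liftYoneda x = Yoneda (\f -> fmap f x)

-- Applies the accumulated function in a single fmap over the structure.
lowerYoneda :: Yoneda f a -> f a
lowerYoneda (Yoneda y) = y id

bindTerm :: Term v -> (v -> Yoneda Term v') -> Term v'
bindTerm (Var x)   f = lowerYoneda (f x)
bindTerm (App l r) f = App (bindTerm l f) (bindTerm r f)
bindTerm (Lam b)   f =
  Lam (bindTerm b (maybe (liftYoneda (Var Nothing)) (fmap Just . f)))

-- Hypothetical name for the answer's Yoneda-backed (>>=).
subst :: Term v -> (v -> Term v') -> Term v'
subst t f = bindTerm t (liftYoneda . f)

-- Same worked example as before: two Lams, so two Justs get composed
-- into one function before the single traversal in lowerYoneda.
example :: Term String
example =
  subst (Lam (Lam (Var (Just (Just ())))))
        (\() -> App (Var "f") (Var "x"))
```

Here the two fmap Just calls introduced by the two Lams only compose functions inside Yoneda; the substituted term is traversed once, when lowerYoneda runs at the Var.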