Tags: f#, nested, polymorphism, fold, recursive-type

How to implement “efficient generalized fold” in F#?


In the paper by Martin et al., I read about efficient generalized folds for nested data types. The paper uses Haskell, and I want to try it in F#.

So far I have managed to follow the Nest example, including the implementation of gfold.

type Pair<'a> = 'a * 'a
type Nest<'a> = Nil | Cons of 'a * Nest<Pair<'a>>

let example =
    Cons(1,
        Cons((2, 3),
            Cons(((4, 5), (6, 7)),
                Nil
            )
        )
    )

let pair (f:'a -> 'b) ((a, b):Pair<'a>) : Pair<'b> = f a, f b

let rec nest<'a, 'r> (f:'a -> 'r) : Nest<'a> -> Nest<'r> = function
    | Nil -> Nil
    | Cons(x, xs) -> Cons(f x, nest (pair f) xs)

//val gfold : e:'r -> f:('a * 'r -> 'r) -> g:(Pair<'a> -> 'a) -> _arg1:Nest<'a> -> 'r
let rec gfold e f g : Nest<'a> -> 'r = function
    | Nil -> e
    | Cons(x, xs) ->
        f(x, gfold e f g (nest g xs))

let uncurry f (a, b) = f a b

let up = uncurry (+)

let sum = example |> gfold 0 up up

Unfortunately, gfold seems to have quadratic complexity, which is why the authors came up with efold. As you can probably guess, that's the one I couldn't get working. After fiddling with many type annotations, I came up with this version, which only has a tiny squiggle left:

let rec efold<'a, 'b, 'r> (e:'r) (f:'a * 'r -> 'r) (g:(Pair<'a> -> Pair<'a>) -> 'a -> 'a) (h:_) (nest:Nest<'a>) : 'r =
    match nest with
    | Nil -> e
    | Cons(x, xs) -> f(h x, efold e f g ((g << pair) h) xs)
                                                        ^^

The only remaining unspecified type is that of h. The compiler infers val h : ('a -> 'a), but I think it needs to involve different types.

The error message provided reads

Error Type mismatch. Expecting a
Nest<'a>
but given a
Nest<Pair<'a>>
The resulting type would be infinite when unifying ''a' and 'Pair<'a>'

With the correct type of h the error should vanish. But I don't understand enough Haskell to translate it to F#.

See also this discussion about a possible typo in the paper.


Update: This is what I understand from kvb's answer:

So h transforms an input type into an intermediate type, like in a regular fold where the accumulator may be of a different type. g is then used to reduce two intermediate-typed values to one, while f takes an intermediate-typed value and an input-typed value to produce the output-typed values. Of course, e is also of that output type.

h is indeed directly applied to the values encountered during recursion. g on the other hand is only used to make h applicable to progressively deeper types.

Judging by the first examples alone, f by itself doesn't seem to do much work apart from applying h and fuelling the recursion. But in the sophisticated approach I can see that it is the most important one with respect to what comes out, i.e. it's the workhorse.

Is that about right?


Solution

  • The correct definition of efold in Haskell is something like:

    efold :: forall n m b.
        (forall a. n a) ->
        (forall a. (m a, n (Pair a)) -> n a) ->
        (forall a. Pair (m a) -> m (Pair a)) ->
        (forall a. (a -> m b) -> Nest a -> n b)
    efold e f g h Nil = e
    efold e f g h (Cons (x, xs)) = f (h x, efold e f g (g . pair h) xs)
    

    This can't be translated to F# in full generality because n and m are "higher-kinded types" - they are type constructors that create a type when given an argument - which aren't supported in F# (and have no clean representation in .NET).

    Interpretation

    Your update asks how to interpret the arguments to the fold. Perhaps the easiest way to see how the fold works is to expand out what happens when you apply the fold to your example. You would get something like this:

    efold e f g h example ≡
        f (h 1, f ((g << pair h) (2, 3), f ((g << pair (g << pair h)) ((4,5), (6,7)), e)))
    

    So h maps values into the type that can serve as f's first argument. g is used to apply h to more deeply nested pairs (so that we can go from using h as a function of type a -> m b to Pair a -> m (Pair b) to Pair (Pair a) -> m (Pair (Pair b)) etc.), and f is repeatedly applied up the spine to combine the results of h with the results of nested calls to f. Finally, e is used exactly once, to serve as the seed of the most deeply nested call to f.

    I think this explanation mostly agrees with what you've deduced. f is certainly critical to combining the results of the different layers. But g matters, too, since it tells you how to combine the pieces within a layer (e.g. when summing the nodes, it needs to sum the left and right nested sums; if you wanted to use a fold to build a new nest where the values at each level are reversed from those of the input, you would use a g which looks roughly like fun (a,b) -> b,a).
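To check the expansion of efold on the example mechanically, one can run the fold with purely symbolic arguments, so that each of e, f, g and h only records its own application as a string. This is a sketch that assumes the "simple approach" specialization of efold from this answer (n _ and m _ fixed, here both string); the name expansion and the string formats are illustrative only.

```fsharp
// Pair, Nest and pair as in the question.
type Pair<'a> = 'a * 'a
type Nest<'a> =
    | Nil
    | Cons of 'a * Nest<Pair<'a>>

let pair (f: 'a -> 'b) ((a, b): Pair<'a>) : Pair<'b> = f a, f b

// Simple specialization: n _ and m _ are fixed types 'n and 'm.
// The explicit type parameters allow the polymorphic recursive call on Nest<Pair<'a>>.
let rec efold<'n, 'm, 'a> (e: 'n) (f: 'm * 'n -> 'n) (g: Pair<'m> -> 'm) (h: 'a -> 'm) : Nest<'a> -> 'n =
    function
    | Nil -> e
    | Cons (x, xs) -> f (h x, efold e f g (g << pair h) xs)

let example = Cons (1, Cons ((2, 3), Cons (((4, 5), (6, 7)), Nil)))

// Symbolic arguments: each one only records that it was applied.
let expansion =
    efold "e"
          (fun (a, r) -> sprintf "f(%s, %s)" a r)
          (fun (a, b) -> sprintf "g(%s, %s)" a b)
          (fun (x: int) -> sprintf "h %d" x)
          example

printfn "%s" expansion
// prints: f(h 1, f(g(h 2, h 3), f(g(g(h 4, h 5), g(h 6, h 7)), e)))
```

The printed term has the same shape as the hand expansion: h is applied once per element, g once per pair node, f once per layer, and e once at the bottom.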

    Simple approach

    One option is to create specialized implementations of efold for each n, m pair you care about. For example, if we want to sum the lengths of the lists contained in a Nest, then n _ and m _ will both just be int. We can generalize slightly, to the case where n _ and m _ don't depend on their arguments:

    let rec efold<'n,'m,'a> (e:'n) (f:'m*'n->'n) (g:Pair<'m> -> 'm) (h:'a->'m) : Nest<'a> -> 'n = function
    | Nil -> e
    | Cons(x,xs) -> f (h x, efold e f g (g << (pair h)) xs)
    
    let total = efold 0 up up id example
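As a sketch of the list-length case mentioned above, where n _ and m _ are both int, h can simply be List.length. The sample nest lists and the helper add are hypothetical; the fold definition is repeated so that the snippet runs on its own.

```fsharp
type Pair<'a> = 'a * 'a
type Nest<'a> =
    | Nil
    | Cons of 'a * Nest<Pair<'a>>

let pair (f: 'a -> 'b) ((a, b): Pair<'a>) : Pair<'b> = f a, f b

// The simple specialization: n _ = 'n and m _ = 'm do not vary with the argument.
let rec efold<'n, 'm, 'a> (e: 'n) (f: 'm * 'n -> 'n) (g: Pair<'m> -> 'm) (h: 'a -> 'm) : Nest<'a> -> 'n =
    function
    | Nil -> e
    | Cons (x, xs) -> f (h x, efold e f g (g << pair h) xs)

// Hypothetical sample data: a nest whose elements are int lists.
let lists : Nest<int list> =
    Cons ([1; 2], Cons (([3], []), Cons ((([4; 5; 6], [7]), ([8], [9; 10])), Nil)))

let add (a: int, b: int) = a + b

// h measures each list; g combines the halves of a pair, f combines the layers.
let totalLength = efold 0 add add List.length lists
// 2 + (1 + 0) + ((3 + 1) + (1 + 2)) = 10
```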
    

    On the other hand, if n and m do use their arguments, then you'd need to define a separate specialization (plus, you may need to create new types for each polymorphic argument, since F#'s encoding of higher-rank types is awkward). For instance, to collect a nest's values into a list you want n 'a = list<'a> and m 'b = 'b. Then instead of defining new types for the argument type of e we can observe that the only value of type forall 'a.list<'a> is [], so we can write:

    type ListIdF =
        abstract Apply : 'a * list<Pair<'a>> -> list<'a>
    
    type ListIdG =
        abstract Apply : Pair<'a> -> Pair<'a>
    
    let rec efold<'a,'b> (f:ListIdF) (g:ListIdG) (h:'a -> 'b) : Nest<'a> -> list<'b> = function
    | Nil -> []
    | Cons(x,xs) -> f.Apply(h x, efold f g (pair h >> g.Apply) xs)
    
    let toList n =
        efold { new ListIdF with member __.Apply(a, l) = a :: (List.collect (fun (x, y) -> [x; y]) l) }
              { new ListIdG with member __.Apply(p) = p }
              id
              n
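The same list specialization also illustrates the earlier point about g: swapping the halves of each pair, instead of returning it unchanged, reverses the values within each level while keeping the levels in spine order. A minimal sketch, assuming the ListIdF/ListIdG encoding above (written here with explicit method type parameters); flatten and swap are hypothetical names.

```fsharp
type Pair<'a> = 'a * 'a
type Nest<'a> =
    | Nil
    | Cons of 'a * Nest<Pair<'a>>

let pair (f: 'a -> 'b) ((a, b): Pair<'a>) : Pair<'b> = f a, f b

// Interfaces with a generic method stand in for the rank-2 arguments f and g.
type ListIdF =
    abstract Apply<'a> : 'a * list<Pair<'a>> -> list<'a>

type ListIdG =
    abstract Apply<'a> : Pair<'a> -> Pair<'a>

let rec efold<'a, 'b> (f: ListIdF) (g: ListIdG) (h: 'a -> 'b) : Nest<'a> -> list<'b> =
    function
    | Nil -> []
    | Cons (x, xs) -> f.Apply(h x, efold f g (pair h >> g.Apply) xs)

// f: put this level's value in front of the flattened values of the deeper levels.
let flatten = { new ListIdF with member __.Apply(a, l) = a :: List.collect (fun (x, y) -> [x; y]) l }

// g: swap the two halves of every pair.
let swap = { new ListIdG with member __.Apply(p) = (snd p, fst p) }

let example = Cons (1, Cons ((2, 3), Cons (((4, 5), (6, 7)), Nil)))

let reversedLevels = efold flatten swap id example
// [1; 3; 2; 7; 6; 5; 4]
```

Each level keeps its position in the output (1, then 2..3, then 4..7), but the values inside a level come out in reverse order.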
    

    Sophisticated approach

    While F# doesn't directly support higher-kinded types, it turns out that it's possible to simulate them in a somewhat faithful way. This is the approach taken by the Higher library. Here's what a minimal version of that would look like.

    We create a type App<'F,'T> that represents some type application F<'T>, where for each type constructor F we create a dummy companion type that can serve as the first type argument to App<_,_>:

    type App<'F, 'T>(token : 'F, value : obj) = 
        do
            if obj.ReferenceEquals(token, Unchecked.defaultof<'F>) then
                raise <| new System.InvalidOperationException("Invalid token")
    
        // Apply the secret token to have access to the encapsulated value
        member self.Apply(token' : 'F) : obj =
            if not (obj.ReferenceEquals(token, token')) then
                raise <| new System.InvalidOperationException("Invalid token")
            value 
    

    Now we can define some companion types for type constructors we care about (and these can generally live in some shared library):

    // App<Const<'a>, 'b> represents a value of type 'a (that is, ignores 'b)
    type Const<'a> private () =
        static let token = Const ()
        static member Inj (value : 'a) =
            App<Const<'a>, 'b>(token, value)
        static member Prj (app : App<Const<'a>, 'b>) : 'a =
            app.Apply(token) :?> _
    
    // App<List, 'a> represents list<'a>
    type List private () = 
        static let token = List()
        static member Inj (value : 'a list) =
            App<List, 'a>(token, value)
        static member Prj (app : App<List, 'a>) : 'a list =
            app.Apply(token) :?> _
    
    // App<Id, 'a> represents just a plain 'a
    type Id private () =
        static let token = Id()
        static member Inj (value : 'a) =
            App<Id, 'a>(token, value)
        static member Prj (app : App<Id, 'a>) : 'a =
            app.Apply(token) :?> _
    
    // App<Nest, 'a> represents a Nest<'a>
    type Nest private () =
        static let token = Nest()
        static member Inj (value : Nest<'a>) =
            App<Nest, 'a>(token, value)
        static member Prj (app : App<Nest, 'a>) : Nest<'a> =
            app.Apply(token) :?> _
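Following the same recipe, a companion type for another type constructor takes only a few lines. As a sketch, here is a hypothetical Opt companion for option<'a> (the name Opt is an assumption, not something taken from the Higher library), with a round trip through Inj and Prj. App is repeated so the snippet is self-contained; the only change is that values are boxed explicitly.

```fsharp
// App<'F, 'T>, essentially as defined above: a boxed value that can only be
// unwrapped by presenting the same token object it was created with.
type App<'F, 'T>(token: 'F, value: obj) =
    do
        if obj.ReferenceEquals(box token, null) then
            raise (System.InvalidOperationException "Invalid token")

    member self.Apply(token': 'F) : obj =
        if not (obj.ReferenceEquals(box token, box token')) then
            raise (System.InvalidOperationException "Invalid token")
        value

// Hypothetical companion type: App<Opt, 'a> represents option<'a>.
type Opt private () =
    static let token = Opt()
    static member Inj(value: 'a option) = App<Opt, 'a>(token, box value)
    static member Prj(app: App<Opt, 'a>) : 'a option = app.Apply(token) :?> _

// Round trip: inject an option, then project it back out.
let wrapped : App<Opt, int> = Opt.Inj(Some 42)
let unwrapped = Opt.Prj wrapped
```

Because the token is private to Opt, only Opt.Prj can unwrap an App<Opt, _>, which is what keeps the unchecked downcast in Prj safe.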
    

    Now we can define the higher-rank types for the arguments of the efficient fold once and for all:

    // forall a. n a
    type E<'N> =
        abstract Apply<'a> : unit -> App<'N,'a>
    
    // forall a. (m a, n (Pair a)) -> n a
    type F<'M,'N> =
        abstract Apply<'a> : App<'M,'a> * App<'N,'a*'a> -> App<'N,'a>
    
    // forall a. Pair (m a) -> m (Pair a)
    type G<'M> =
        abstract Apply<'a> : App<'M,'a> * App<'M,'a> -> App<'M,'a*'a>
    

    so that the fold is just:

    let rec efold<'N,'M,'a,'b> (e:E<'N>) (f:F<'M,'N>) (g:G<'M>) (h:'a -> App<'M,'b>) : Nest<'a> -> App<'N,'b> = function
    | Nil -> e.Apply()
    | Cons(x,xs) -> f.Apply(h x, efold e f g (g.Apply << pair h) xs)
    

    Now to call efold we need to sprinkle in some calls to the various Inj and Prj methods, but otherwise everything looks much as we'd expect:

    let toList n = 
        efold { new E<_> with member __.Apply() = List.Inj [] } 
              { new F<_,_> with member __.Apply(m,n) = Id.Prj m :: (n |> List.Prj |> List.collect (fun (x,y) -> [x;y])) |> List.Inj }
              { new G<_> with member __.Apply(m1,m2) = (Id.Prj m1, Id.Prj m2) |> Id.Inj }
              Id.Inj
              n
        |> List.Prj
    
    let sumElements n =
        efold { new E<_> with member __.Apply() = Const.Inj 0 }
              { new F<_,_> with member __.Apply(m,n) = Const.Prj m + Const.Prj n |> Const.Inj }
              { new G<_> with member __.Apply(m1,m2) = Const.Prj m1 + Const.Prj m2 |> Const.Inj }
              Const.Inj
              n
        |> Const.Prj
    
    let reverse n = 
        efold { new E<_> with member __.Apply() = Nest.Inj Nil }
              { new F<_,_> with member __.Apply(m,n) = Cons(Id.Prj m, Nest.Prj n) |> Nest.Inj }
              { new G<_> with member __.Apply(m1,m2) = (Id.Prj m2, Id.Prj m1) |> Id.Inj }
              Id.Inj
              n
        |> Nest.Prj
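As a further sketch of the same pattern, counting the elements of a nest only needs a different h: every element is mapped to Const 1, and f and g add counts just as in sumElements. The name countElements is hypothetical, the type arguments (Const<int>, unit) are written out explicitly instead of being inferred to keep the snippet unambiguous, and the supporting definitions are repeated from above.

```fsharp
type Pair<'a> = 'a * 'a
type Nest<'a> =
    | Nil
    | Cons of 'a * Nest<Pair<'a>>

let pair (f: 'a -> 'b) ((a, b): Pair<'a>) : Pair<'b> = f a, f b

// App<'F, 'T> and the Const companion type, essentially as defined above.
type App<'F, 'T>(token: 'F, value: obj) =
    do
        if obj.ReferenceEquals(box token, null) then
            raise (System.InvalidOperationException "Invalid token")

    member self.Apply(token': 'F) : obj =
        if not (obj.ReferenceEquals(box token, box token')) then
            raise (System.InvalidOperationException "Invalid token")
        value

type Const<'a> private () =
    static let token = Const<'a>()
    static member Inj(value: 'a) = App<Const<'a>, 'b>(token, box value)
    static member Prj(app: App<Const<'a>, 'b>) : 'a = app.Apply(token) :?> _

// Encodings of the rank-2 arguments, as above.
type E<'N> =
    abstract Apply<'a> : unit -> App<'N, 'a>

type F<'M, 'N> =
    abstract Apply<'a> : App<'M, 'a> * App<'N, 'a * 'a> -> App<'N, 'a>

type G<'M> =
    abstract Apply<'a> : App<'M, 'a> * App<'M, 'a> -> App<'M, 'a * 'a>

let rec efold<'N, 'M, 'a, 'b> (e: E<'N>) (f: F<'M, 'N>) (g: G<'M>) (h: 'a -> App<'M, 'b>) : Nest<'a> -> App<'N, 'b> =
    function
    | Nil -> e.Apply()
    | Cons (x, xs) -> f.Apply(h x, efold e f g (g.Apply << pair h) xs)

// Hypothetical example: every element becomes Const 1; g adds the two halves
// of a pair, and f adds the count of one layer to the count of the deeper layers.
let countElements (nest: Nest<'a>) : int =
    efold { new E<Const<int>> with member __.Apply() = Const<int>.Inj 0 }
          { new F<Const<int>, Const<int>> with
                member __.Apply(m, rest) = Const<int>.Prj m + Const<int>.Prj rest |> Const<int>.Inj }
          { new G<Const<int>> with
                member __.Apply(m1, m2) = Const<int>.Prj m1 + Const<int>.Prj m2 |> Const<int>.Inj }
          (fun _ -> (Const<int>.Inj 1 : App<Const<int>, unit>))
          nest
    |> Const<int>.Prj

let example = Cons (1, Cons ((2, 3), Cons (((4, 5), (6, 7)), Nil)))
let count = countElements example
// 7
```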
    

    Hopefully the pattern here is clear: in each object expression, the application method projects out each argument, operates on them, and then injects the result back into an App<_,_> type. With some inline magic, we can make this look even more consistent (at the cost of a few type annotations):

    let inline (|Prj|) (app:App< ^T, 'a>) = (^T : (static member Prj : App< ^T, 'a> -> 'b) app)
    let inline prj (Prj x) = x
    let inline inj x = (^T : (static member Inj : 'b -> App< ^T, 'a>) x)
    
    let toList n = 
        efold { new E<List> with member __.Apply() = inj [] } 
              { new F<Id,_> with member __.Apply(Prj m, Prj n) = m :: (n |> List.collect (fun (x,y) -> [x;y])) |> inj }
              { new G<_> with member __.Apply(Prj m1,Prj m2) = (m1, m2) |> inj }
              inj
              n
        |> prj
    
    let sumElements n =
        efold { new E<Const<_>> with member __.Apply() = inj 0 }
              { new F<Const<_>,_> with member __.Apply(Prj m, Prj n) = m + n |> inj }
              { new G<_> with member __.Apply(Prj m1,Prj m2) = m1 + m2 |> inj }
              inj
              n
        |> prj
    
    let reverse n = 
        efold { new E<_> with member __.Apply() = Nest.Inj Nil }
              { new F<Id,_> with member __.Apply(Prj m,Prj n) = Cons(m, n) |> inj }
              { new G<_> with member __.Apply(Prj m1,Prj m2) = (m2, m1) |> inj }
              inj
              n
        |> prj