Search code examples
Tags: f#, tree, monads, tail-recursion, continuations

How can I make this F# function not cause a stack overflow?


I've written an interesting function in F# which can traverse and map any data structure (much like the everywhere function available in Haskell's Scrap Your Boilerplate). Unfortunately it quickly causes a stack overflow on even fairly small data structures. I was wondering how I can convert it to a tail recursive version, continuation passing style version or an imperative equivalent algorithm. I believe F# supports monads, so the continuation monad is an option.

// Module-level caches of precomputed reflection readers, consulted by map5
// so that each reader is only computed once per type (and, for unionReaders,
// once per type/tag pair). These are used for a 50% speedup.
let mutable tupleReaders : (System.Type * (obj -> obj[])) list = []
let mutable unionTagReaders : (System.Type * (obj -> int)) list = []
let mutable unionReaders : ((System.Type * int) * (obj -> obj[])) list = []
let mutable unionCaseInfos : (System.Type * Microsoft.FSharp.Reflection.UnionCaseInfo[]) list = []
let mutable recordReaders : (System.Type * (obj -> obj[])) list = []

(*
    Traverses any data structure in a preorder traversal
    Calls f, g, h, i, j which determine the mapping of the current node being considered

    WARNING: Not able to handle option types
    At runtime, option None values are represented as null and so you cannot determine their runtime type.

    See http://stackoverflow.com/questions/21855356/dynamically-determine-type-of-option-when-it-has-value-none
    http://stackoverflow.com/questions/13366647/how-to-generalize-f-option
*)
open Microsoft.FSharp.Reflection
// Maps over an arbitrary value `src`. Every node whose runtime type equals (or
// derives from) 'a, 'b, 'c, 'd or 'e is passed through f, g, h, i or j
// respectively; unions, tuples and records are then rebuilt from their
// traversed fields. The result is unboxed back to 'z.
//
// NOTE: `traverse` is called via Array.map inside `drill`, i.e. not in tail
// position, so the call stack grows with the nesting depth of `src`.
let map5<'a,'b,'c,'d,'e,'z> (f:'a->'a) (g:'b->'b) (h:'c->'c) (i:'d->'d) (j:'e->'e) (src:'z) =
    // Runtime types that the five mapping functions apply to.
    let ft = typeof<'a>
    let gt = typeof<'b>
    let ht = typeof<'c>
    let it = typeof<'d>
    let jt = typeof<'e>

    // Rebuilds `o` from its traversed fields when it is a union, tuple or
    // record; null and every other kind of value are returned unchanged.
    let rec drill (o:obj) : obj =
        if o = null then
            o
        else
            let ot = o.GetType()
            if FSharpType.IsUnion(ot) then
                // Case tag of `o`; the tag reader is cached per type in unionTagReaders.
                let tag = match List.tryFind (fst >> ot.Equals) unionTagReaders with
                              | Some (_, reader) ->
                                   reader o
                              | None ->
                                   let newReader = FSharpValue.PreComputeUnionTagReader(ot)
                                   unionTagReaders <- (ot, newReader)::unionTagReaders
                                   newReader o
                // UnionCaseInfo for that tag; the case array is cached per type in unionCaseInfos.
                let info = match List.tryFind (fst >> ot.Equals) unionCaseInfos with
                               | Some (_, caseInfos) ->
                                   Array.get caseInfos tag
                               | None ->
                                   let newCaseInfos = FSharpType.GetUnionCases(ot)
                                   unionCaseInfos <- (ot, newCaseInfos)::unionCaseInfos
                                   Array.get newCaseInfos tag
                // Field values of this case; the field reader is cached per (type, tag) in unionReaders.
                let vals = match List.tryFind (fun ((tau, tag'), _) -> ot.Equals tau && tag = tag') unionReaders with
                               | Some (_, reader) ->
                                   reader o
                               | None ->
                                   let newReader = FSharpValue.PreComputeUnionReader info
                                   unionReaders <- ((ot, tag), newReader)::unionReaders
                                   newReader o
                // Recurse into every field (non-tail call), then rebuild the same case.
                FSharpValue.MakeUnion(info, Array.map traverse vals)
            elif FSharpType.IsTuple(ot) then
                // Tuple elements; the reader is cached per type in tupleReaders.
                let fields = match List.tryFind (fst >> ot.Equals) tupleReaders with
                                 | Some (_, reader) ->
                                     reader o
                                 | None ->
                                     let newReader = FSharpValue.PreComputeTupleReader(ot)
                                     tupleReaders <- (ot, newReader)::tupleReaders
                                     newReader o
                // Recurse into every element (non-tail call), then rebuild a tuple of the same type.
                FSharpValue.MakeTuple(Array.map traverse fields, ot)
            elif FSharpType.IsRecord(ot) then
                // Record field values; the reader is cached per type in recordReaders.
                let fields = match List.tryFind (fst >> ot.Equals) recordReaders with
                                 | Some (_, reader) ->
                                     reader o
                                 | None ->
                                     let newReader = FSharpValue.PreComputeRecordReader(ot)
                                     recordReaders <- (ot, newReader)::recordReaders
                                     newReader o
                // Recurse into every field (non-tail call), then rebuild a record of the same type.
                FSharpValue.MakeRecord(ot, Array.map traverse fields)
            else
                // Not a union, tuple or record: treated as a leaf.
                o

    // Applies the first mapper (tested in the order f, g, h, i, j) whose type
    // equals the runtime type of `o` or is a base class of it, then drills
    // into the mapped value. null is passed to drill unchanged.
    and traverse (o:obj) =
        let parent =
            if o = null then
                o
            else
                let ot = o.GetType()
                if ft = ot || ot.IsSubclassOf(ft) then
                    f (o :?> 'a) |> box
                elif gt = ot || ot.IsSubclassOf(gt) then
                    g (o :?> 'b) |> box
                elif ht = ot || ot.IsSubclassOf(ht) then
                    h (o :?> 'c) |> box
                elif it = ot || ot.IsSubclassOf(it) then
                    i (o :?> 'd) |> box
                elif jt = ot || ot.IsSubclassOf(jt) then
                    j (o :?> 'e) |> box
                else
                    o
        drill parent
    traverse src |> unbox : 'z

Solution

  • Try this (I just pass a continuation function as a parameter):

    namespace Solution

    [<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
    [<AutoOpen>]
    module Solution =

        // Caches of precomputed reflection readers, keyed by runtime type (and,
        // for unionReaders, by type and case tag).
        // These are used for a 50% speedup
        let mutable tupleReaders : List<System.Type * (obj -> obj[])> = []
        let mutable unionTagReaders : List<System.Type * (obj -> int)> = []
        let mutable unionReaders : List<(System.Type * int) * (obj -> obj[])> = []
        let mutable unionCaseInfos : List<System.Type * Microsoft.FSharp.Reflection.UnionCaseInfo[]> = []
        let mutable recordReaders : List<System.Type * (obj -> obj[])> = []

        (*
            Traverses any data structure in a preorder traversal
            Calls f, g, h, i, j which determine the mapping of the current node being considered

            The traversal is written in continuation-passing style: every recursive
            call is a tail call, so the nesting depth of the data lives in heap-allocated
            closures instead of on the call stack (tail calls must be enabled at build time).

            WARNING: Not able to handle option types
            At runtime, option None values are represented as null and so you cannot determine their runtime type.

            See http://stackoverflow.com/questions/21855356/dynamically-determine-type-of-option-when-it-has-value-none
            http://stackoverflow.com/questions/13366647/how-to-generalize-f-option
        *)
        open Microsoft.FSharp.Reflection
        let map5<'a,'b,'c,'d,'e,'z> (f:'a->'a) (g:'b->'b) (h:'c->'c) (i:'d->'d) (j:'e->'e) (src:'z) =
            // Runtime types that the five mapping functions apply to.
            let ft = typeof<'a>
            let gt = typeof<'b>
            let ht = typeof<'c>
            let it = typeof<'d>
            let jt = typeof<'e>

            // Splits a node into (its fields, a function that rebuilds the node from
            // replacement fields). Leaves (null, or anything that is not a union,
            // tuple or record) have no fields and "rebuild" to themselves.
            let drill (o:obj) : obj[] option * (obj[] -> obj) =
                if o = null then
                    None, (fun (_: obj[]) -> o)
                else
                    let ot = o.GetType()
                    if FSharpType.IsUnion(ot) then
                        // Case tag of o; the tag reader is cached per type.
                        let tag =
                            match List.tryFind (fst >> ot.Equals) unionTagReaders with
                            | Some (_, reader) ->
                                reader o
                            | None ->
                                let newReader = FSharpValue.PreComputeUnionTagReader(ot)
                                unionTagReaders <- (ot, newReader)::unionTagReaders
                                newReader o
                        // UnionCaseInfo for that tag; the case array is cached per type.
                        let info =
                            match List.tryFind (fst >> ot.Equals) unionCaseInfos with
                            | Some (_, caseInfos) ->
                                Array.get caseInfos tag
                            | None ->
                                let newCaseInfos = FSharpType.GetUnionCases(ot)
                                unionCaseInfos <- (ot, newCaseInfos)::unionCaseInfos
                                Array.get newCaseInfos tag
                        // Field values of this case; the reader is cached per (type, tag).
                        let vals =
                            match List.tryFind (fun ((tau, tag'), _) -> ot.Equals tau && tag = tag') unionReaders with
                            | Some (_, reader) ->
                                reader o
                            | None ->
                                let newReader = FSharpValue.PreComputeUnionReader info
                                unionReaders <- ((ot, tag), newReader)::unionReaders
                                newReader o
                        Some vals, (fun newVals -> FSharpValue.MakeUnion(info, newVals))
                    elif FSharpType.IsTuple(ot) then
                        let fields =
                            match List.tryFind (fst >> ot.Equals) tupleReaders with
                            | Some (_, reader) ->
                                reader o
                            | None ->
                                let newReader = FSharpValue.PreComputeTupleReader(ot)
                                tupleReaders <- (ot, newReader)::tupleReaders
                                newReader o
                        Some fields, (fun newFields -> FSharpValue.MakeTuple(newFields, ot))
                    elif FSharpType.IsRecord(ot) then
                        let fields =
                            match List.tryFind (fst >> ot.Equals) recordReaders with
                            | Some (_, reader) ->
                                reader o
                            | None ->
                                let newReader = FSharpValue.PreComputeRecordReader(ot)
                                recordReaders <- (ot, newReader)::recordReaders
                                newReader o
                        Some fields, (fun newFields -> FSharpValue.MakeRecord(ot, newFields))
                    else
                        None, (fun (_: obj[]) -> o)

            // Maps the node o (first matching mapper in the order f, g, h, i, j),
            // traverses each of its fields, rebuilds the node from the mapped
            // fields and hands the result to cont.
            let rec traverse (o:obj) (cont: obj -> obj) : obj =
                let parent =
                    if o = null then
                        o
                    else
                        let ot = o.GetType()
                        if ft = ot || ot.IsSubclassOf(ft) then
                            f (o :?> 'a) |> box
                        elif gt = ot || ot.IsSubclassOf(gt) then
                            g (o :?> 'b) |> box
                        elif ht = ot || ot.IsSubclassOf(ht) then
                            h (o :?> 'c) |> box
                        elif it = ot || ot.IsSubclassOf(it) then
                            i (o :?> 'd) |> box
                        elif jt = ot || ot.IsSubclassOf(jt) then
                            j (o :?> 'e) |> box
                        else
                            o
                // Named `rebuild` (not `f`) so the mapper parameter f is not shadowed.
                let children, rebuild = drill parent

                match children with
                | None ->
                    rebuild [||] |> cont
                | Some fields ->
                    // Visit the fields left to right. Each field is traversed with a
                    // continuation that stores its mapped value and moves on to the
                    // next index. This treats 0, 1 and many fields the same way, so a
                    // single-field node has its child traversed like any other
                    // (the field array itself is never passed to traverse).
                    let mapped : obj[] = Array.zeroCreate fields.Length
                    let rec visit index =
                        if index = fields.Length then
                            rebuild mapped |> cont
                        else
                            traverse (fields.[index]) (fun child ->
                                mapped.[index] <- child
                                visit (index + 1))
                    visit 0

            traverse (box src) id |> unbox : 'z
    

    You should build this with the "Generate tail calls" option enabled (by default, this option is disabled in Debug mode but enabled in Release).

    Example:

    // Mutually recursive unions whose cases each carry a single field; they are
    // used below to build one very deeply nested value.
    type A1 =
        | A of A2
        | B of int

    and A2 =
        | A of A1
        | B of int

    and Root =
        | A1 of A1
        | A2 of A2

    [<EntryPoint>]
    let main args =
        // Wraps the payload of `elem` in one more layer per step, alternating
        // between A1 and A2, until `n` reaches zero.
        let rec build (elem: Root) n =
            match n, elem with
            | 0, _ -> elem
            | _, Root.A1 inner -> build (Root.A2(A2.A inner)) (n - 1)
            | _, Root.A2 inner -> build (Root.A1(A1.A inner)) (n - 1)

        let depth = 100000
        let tree = build (Root.A1(A1.B 2)) depth

        // Run the traversal with five identity mappers and print the result.
        let a = map5 id id id id id tree
        printf "%A" a
        0
    

    This code finished without a stack overflow exception.