Search code examples
f#, computation-expression

F# Flattening Nested Tuples in Computation Expression


I have a computation expression that I want to return a sequence of pairs whose first element is a flattened tuple and whose second element is an int. I am trying to use method overloading to accomplish this, but the compiler throws an error saying it cannot find a unique overload. I do not know how to help the compiler resolve it; the choice of overload looks unambiguous to me.

// Question version of the builder: it tries to pick a flattening Run overload
// purely from the shape of the nested tuple type.
type IntBuilder () =

    // Yield passes the int produced by the innermost loop body through unchanged.
    member inline this.Yield (i:int) =
        i

    // Outer loop: the body is itself a loop returning seq<'b * int>. Each inner
    // (idx, i) pair is re-keyed as ((x, idx), i), so keys nest to the right.
    member inline this.For(source:seq<'a>, body:'a -> seq<'b * int>) =
        source
        |> Seq.collect (fun x -> body x |> Seq.map (fun (idx, i) -> (x, idx), i))

    // Innermost loop: pairs each element with the int its body yields.
    member inline this.For(source:seq<'a>, body:'a -> int) =
        source |> Seq.map (fun x -> x, body x)

    // Run overloads intended to flatten right-nested keys of depth 4, 3, 2, 1.
    // NOTE(review): these argument types overlap. seq<'a * 'v> (the last
    // overload) also unifies with every other Run's argument type, and
    // seq<('a * 'b) * 'v> unifies with the deeper shapes. Presumably that is
    // why the compiler cannot find a unique overload.
    member inline this.Run(source:seq<('a * ('b * ('c * 'd))) * 'v>) =
        source 
        |> Seq.map (fun ((x, (y, (z, a))), d) -> (x, y, z, a), d)

    member inline this.Run(source:seq<('a * ('b * 'c)) * 'v>) =
        source 
        |> Seq.map (fun ((x, (y, z)), d) -> (x, y, z), d)

    // Identity mapping: a two-element key is already flat.
    member inline this.Run(source:seq<('a * 'b) * 'v>) =
        source 
        |> Seq.map (fun ((x, y), d) -> (x, y), d)

    // Identity mapping: a single key needs no flattening.
    member inline this.Run(source:seq<'a * 'v>) =
        source 
        |> Seq.map (fun (x, d) -> x, d)

// Builder instance used by the computation expression below.
let intBuilder = IntBuilder ()
// Four nested loops over 1..2; the innermost body yields i + j + k + l.
// Desired type: seq<(int * int * int * int) * int>. The question reports
// the right-nested form seq<(int * (int * (int * int))) * int> instead.
let c = 
    intBuilder {
        for i in 1..2 do
            for j in 1..2 do
                for k in 1..2 do
                    for l in 1..2 -> 
                         i + j + k + l
    }

// What I get
c : seq<(int * (int * (int * int))) * int>

// What I want
c : seq<(int * int * int * int) * int>

In this case c has type seq<(int * (int * (int * int))) * int>. I want the IntBuilder computation to return seq<(int * int * int * int) * int> instead. How do I make this happen?


Solution

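  • It looks like the only way to get this to work is to wrap each nesting level in its own concrete type. The For and Run overloads then take distinct argument types, and the compiler can pick a unique one: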

    /// Result of one loop level: each element pairs a key with the yielded int.
    type T1<'a> = T1 of seq<'a * int>
    /// Two loop levels: keys are pairs ('a * 'b).
    type T2<'a, 'b> = T2 of seq<('a * 'b) * int>
    /// Three loop levels: keys are right-nested, ('a * ('b * 'c)).
    type T3<'a, 'b, 'c> = T3 of seq<('a * ('b * 'c)) * int>
    /// Four loop levels: keys are right-nested, ('a * ('b * ('c * 'd))).
    type T4<'a, 'b, 'c, 'd> = T4 of seq<('a * ('b * ('c * 'd))) * int>
    /// Five loop levels: keys are right-nested, ('a * ('b * ('c * ('d * 'e)))).
    type T5<'a, 'b, 'c, 'd, 'e> = T5 of seq<('a * ('b * ('c * ('d * 'e)))) * int>
    
    /// Builder for up to five nested `for` loops whose innermost body yields an int.
    /// Each nesting depth is wrapped in its own type (T1..T5), so the For and Run
    /// overloads take distinct argument types and resolve uniquely. Run then
    /// turns the right-nested key tuple into a flat tuple paired with the int.
    type IntBuilder () =
        /// Pairs every outer element with each (key, value) produced by `inner`,
        /// re-keying it as ((outer, key), value). Keys therefore nest to the right.
        let nest (source: seq<'s>) (inner: 's -> seq<'k * int>) =
            source
            |> Seq.collect (fun outer ->
                inner outer
                |> Seq.map (fun (key, value) -> (outer, key), value))

        /// Passes the int yielded by the innermost loop body through unchanged.
        member this.Yield (i:int) =
            i

        /// Innermost loop: pairs each element with the int its body yields.
        member this.For(source:seq<'a>, body:'a -> int) =
            source |> Seq.map (fun x -> x, body x)
            |> T1.T1

        /// Loop whose body is a one-level loop; produces two-part keys.
        // Unwrap into a fresh name (`inner`) so the outer element `x` is not shadowed.
        member this.For(source:seq<'a>, body:'a -> T1<'b>) =
            nest source (fun x -> let (T1.T1 inner) = body x in inner)
            |> T2.T2

        /// Loop whose body is a two-level loop; produces three-part keys.
        member this.For(source:seq<'a>, body:'a -> T2<'b,'c>) =
            nest source (fun x -> let (T2.T2 inner) = body x in inner)
            |> T3.T3

        /// Loop whose body is a three-level loop; produces four-part keys.
        member this.For(source:seq<'a>, body:'a -> T3<'b,'c,'d>) =
            nest source (fun x -> let (T3.T3 inner) = body x in inner)
            |> T4.T4

        /// Loop whose body is a four-level loop; produces five-part keys.
        member this.For(source:seq<'a>, body:'a -> T4<'b,'c,'d,'e>) =
            nest source (fun x -> let (T4.T4 inner) = body x in inner)
            |> T5.T5

        /// One loop level: the key is already flat.
        member inline this.Run(T1.T1 source) =
            source

        /// Two loop levels: a pair key is already flat.
        member inline this.Run(T2.T2 source) =
            source

        /// Three loop levels: flatten (x, (y, z)) into (x, y, z).
        member inline this.Run(T3.T3 source) =
            source
            |> Seq.map (fun ((x, (y, z)), d) -> (x, y, z), d)

        /// Four loop levels: flatten (x, (y, (z, a))) into (x, y, z, a).
        member inline this.Run(T4.T4 source) =
            source
            |> Seq.map (fun ((x, (y, (z, a))), d) -> (x, y, z, a), d)

        /// Five loop levels: flatten (x, (y, (z, (a, b)))) into (x, y, z, a, b).
        // Previously missing: For over a T4 body produced T5, but nothing could Run it.
        member inline this.Run(T5.T5 source) =
            source
            |> Seq.map (fun ((x, (y, (z, (a, b)))), d) -> (x, y, z, a, b), d)
    
    // Builder instance that backs the `intBuilder { ... }` syntax.
    let intBuilder = new IntBuilder()


    // Four nested loops over 1..2; the innermost body yields the sum of the
    // loop variables, so each flattened (i, j, k, l) key is paired with i + j + k + l.
    let c =
        intBuilder {
            for i in 1..2 do
                for j in 1..2 do
                    for k in 1..2 do
                        for l in 1..2 do
                            yield i + j + k + l
        }
    

    FSI output:

    val c : seq<(int * int * int * int) * int>
    
    > c;;
    val it : seq<(int * int * int * int) * int> =
      seq
        [((1, 1, 1, 1), 4); ((1, 1, 1, 2), 5); ((1, 1, 2, 1), 5);
         ((1, 1, 2, 2), 6); ...]