Search code examples
Tags: f#, tree, immutability, mutable

How would you convert this mutable tree into an immutable one?


How would you convert type Node into an immutable tree?

This class implements a range tree that does not allow overlapping or adjacent ranges and instead joins them. For example, if the root node is {min = 10; max = 20}, then its right child and all of that child's descendants must have min and max values greater than 21 (a range starting at 21 would be adjacent to the root and would be merged into it). The max value of a range must be greater than or equal to its min. I included a test function so you can run this as is; it prints any cases that fail.

I recommend starting with the Insert method to read this code.

module StackOverflowQuestion

open System

/// A closed interval [min, max] of int64 values.
/// ToString renders it as "(min, max)".
type Range =
    { min : int64
      max : int64 }
    override r.ToString() = sprintf "(%d, %d)" r.min r.max

/// Mutable node of a binary search tree that stores disjoint ranges and merges
/// overlapping or adjacent ranges on insertion.  The intended invariant (per
/// the question text) is that every range in the left subtree ends below
/// `min - 1` and every range in the right subtree starts above `max + 1`.
[<AllowNullLiteralAttribute>]
type Node(left:Node, right:Node, range:Range) =
    let mutable left = left
    let mutable right = right
    let mutable range = range

    // Mirror of cleanRight, used after this node's max has grown: walks the
    // left spine below `node` (initially this node's right child) looking for
    // the first range that the new max overlaps or touches.  Wholly covered
    // nodes are replaced by their right child; the first overlapping node is
    // absorbed into `range.max` and removed.
    let rec cleanLeft(node : Node) =
        if node.Left = null then
            ()
        elif range.max < node.Left.Range.min - 1L then
            // Gap before node.Left: any overlap lies further left.
            cleanLeft(node.Left)
        elif range.max <= node.Left.Range.max then
            // node.Left overlaps or touches the new max: absorb and remove it.
            range <- {min = range.min; max = node.Left.Range.max}
            node.Left <- node.Left.Right
        else
            // node.Left is wholly covered: remove it and keep checking.
            node.Left <- node.Left.Right
            cleanLeft(node)

    // Used after this node's min has shrunk: walks the right spine below
    // `node` (initially this node's left child) looking for the first range
    // that the new min overlaps or touches.  Wholly covered nodes are replaced
    // by their left child; the first overlapping node is absorbed into
    // `range.min` and removed.
    let rec cleanRight(node : Node) =
        if node.Right = null then
            ()
        elif range.min > node.Right.Range.max + 1L then
            // Gap after node.Right: any overlap lies further right.
            cleanRight(node.Right)
        elif range.min >= node.Right.Range.min then
            // node.Right overlaps or touches the new min: absorb and remove it.
            range <- {min = node.Right.Range.min; max = range.max}
            node.Right <- node.Right.Left
        else
            // node.Right is wholly covered: remove it and keep checking.
            node.Right <- node.Right.Left
            cleanRight(node)

    // Called (always with `this`) whenever this node's min decreases.  Left
    // children wholly covered by the widened range are dropped; one that
    // overlaps or touches it is absorbed.  If the left child lies entirely
    // before the range, cleanRight searches its right spine instead.
    let rec mergeLeft(node : Node) =
        if node.Left = null then
            ()
        elif range.min < node.Left.Range.min then
            // `<` rather than `<= min - 1L`: the latter wraps around when the
            // left child starts at Int64.MinValue and would discard it wrongly.
            node.Left <- node.Left.Left
            mergeLeft(node)
        elif range.min <= node.Left.Range.max + 1L then
            range <- {min = node.Left.Range.min; max = range.max}
            node.Left <- node.Left.Left
        else
            cleanRight(node.Left)

    // Mirror of mergeLeft, called (always with `this`) whenever this node's
    // max increases.
    let rec mergeRight(node : Node) =
        if node.Right = null then
            ()
        elif range.max > node.Right.Range.max then
            // `>` rather than `>= max + 1L`: the latter wraps around when the
            // right child ends at Int64.MaxValue and would discard it wrongly.
            node.Right <- node.Right.Right
            mergeRight(node)
        elif range.max >= node.Right.Range.min - 1L then
            range <- {min = range.min; max = node.Right.Range.max}
            node.Right <- node.Right.Right
        else
            cleanLeft(node.Right)

    // Classifies a new range r against this node's range.
    let (|Before|After|BeforeOverlap|AfterOverlap|Superset|Subset|) r =
        // `r.min - 1L > range.max` (guarded by r.min > range.max, so it cannot
        // wrap) rather than `r.min > range.max + 1L`, which wraps when
        // range.max = Int64.MaxValue and would classify nearly any r as After.
        if r.min > range.max && r.min - 1L > range.max then After
        elif r.min >= range.min then
            if r.max <= range.max then Subset
            else AfterOverlap
        // From here on r.min < range.min, so `range.min - 1L` cannot wrap.
        elif r.max < range.min - 1L then Before
        elif r.max <= range.max then BeforeOverlap
        else Superset

    /// Inserts r into the subtree rooted at this node.  A range that neither
    /// overlaps nor touches this node's range goes to the matching child; one
    /// that does widens this node's range, which then absorbs any children it
    /// now reaches.  A subset of this node's range changes nothing.
    member this.Insert r = 
        match r with
        | After ->
            if right = null then
                right <- Node(null, null, r)
            else
                right.Insert(r)
        | AfterOverlap ->
            range <- {min = range.min; max = r.max}
            mergeRight(this)
        | Before -> 
            if left = null then
                left <- Node(null, null, r)
            else
                left.Insert(r)
        | BeforeOverlap -> 
            range <- {min = r.min; max = range.max}
            mergeLeft(this)
        | Superset ->
            range <- r
            mergeLeft(this)
            mergeRight(this)
        | Subset -> ()

    member this.Left with get() : Node = left and set(x) = left <- x
    member this.Right with get() : Node = right and set(x) = right <- x
    member this.Range with get() : Range = range and set(x) = range <- x

    /// Compares two nodes by their Range.  Null-safe, so a `==` comparison
    /// against null from another .NET language no longer throws; F#'s own
    /// `=` does not call this member.
    static member op_Equality (a : Node, b : Node) =
        match box a, box b with
        | null, null -> true
        | null, _ | _, null -> false
        | _ -> a.Range = b.Range

    override this.ToString() =
        sprintf "%A" this.Range

/// Mutable set of disjoint, non-adjacent ranges backed by a tree of Nodes.
type RangeTree() =
    // Stays null until the first range is added.
    let mutable root : Node = null

    /// Adds a range, merging it with any stored range it overlaps or touches.
    member this.Add(range) =
        match root with
        | null -> root <- Node(null, null, range)
        | top -> top.Insert(range)

    /// Builds a tree by adding each range of `values` in order.
    static member fromArray(values : Range seq) =
        let tree = RangeTree()
        for v in values do
            tree.Add(v)
        tree

    /// The stored ranges in ascending order, produced by a lazy in-order walk.
    member this.Seq
        with get() =
            let rec walk (current : Node) =
                seq {
                    match current with
                    | null -> ()
                    | n ->
                        yield! walk n.Left
                        yield {min = n.Range.min; max = n.Range.max}
                        yield! walk n.Right
                }
            walk root

/// Randomized self-consistency test for the mutable RangeTree.  For each seed
/// it inserts the same 150 random ranges in three shuffled orders.  It prints
/// the inputs and results of any case where the three trees differ, or where a
/// result is not an ascending list of disjoint, non-adjacent ranges.
let TestRange() =
    printf "\n"

    // 50 pseudo-random ranges for seed n: min in [0, 1500), max - min in [0, 14].
    let source(n) = 
        let rnd = new Random(n)
        let rand x = rnd.NextDouble() * float x |> int64
        let rangeRnd() =
            let a = rand 1500
            {min = a; max = a + rand 15}
        Array.init 50 (fun _ -> rangeRnd())

    // In-place Fisher-Yates shuffle seeded by n.  The bound is i + 1 because
    // Random.Next's upper bound is exclusive; Next(i) would be Sattolo's
    // algorithm, which only ever produces cyclic permutations.
    let shuffle n (array:_[]) =
        let rnd = new Random(n)
        for i in 0 .. array.Length - 1 do
            let j = rnd.Next(i + 1)
            let temp = array.[i]
            array.[i] <- array.[j]
            array.[j] <- temp
        array

    // Full sequence equality.  Seq.forall2 ignores the tail of the longer
    // sequence, so it cannot detect a result with extra or missing ranges.
    let sameRanges (a : Range seq) (b : Range seq) =
        List.ofSeq a = List.ofSeq b

    // True when every range has min <= max and consecutive ranges ascend with
    // at least one value between them (no overlap, no adjacency).
    let wellFormed (ranges : Range seq) =
        let rs = List.ofSeq ranges
        List.forall (fun r -> r.min <= r.max) rs
        && (rs |> Seq.pairwise |> Seq.forall (fun (a, b) -> a.max + 1L < b.min))

    let print dataSet =
        dataSet |> Seq.iter (fun r -> printf "%s " <| string r)

    let testRangeAdd n i =
        let dataSet1 = source (n+0)
        let dataSet2 = source (n+1)
        let dataSet3 = source (n+2)
        let result1 = Array.concat [dataSet1; dataSet2; dataSet3] |> shuffle (i+3) |> RangeTree.fromArray 
        let result2 = Array.concat [dataSet2; dataSet3; dataSet1] |> shuffle (i+4) |> RangeTree.fromArray 
        let result3 = Array.concat [dataSet3; dataSet1; dataSet2] |> shuffle (i+5) |> RangeTree.fromArray 
        // Equality is transitive, so checking 1-2 and 2-3 covers all pairs.
        let consistent = sameRanges result1.Seq result2.Seq && sameRanges result2.Seq result3.Seq
        let valid = wellFormed result1.Seq

        if not (consistent && valid) then
            printf "\n\nTest# %A: " n
            printf "\nSource 1: %A: " (n+0)
            dataSet1 |> print
            printf "\nSource 2: %A: " (n+1)
            dataSet2 |> print
            printf "\nSource 3: %A: " (n+2)
            dataSet3 |> print
            printf "\nResult 1: %A: " (n+0)
            result1.Seq |> print
            printf "\nResult 2: %A: " (n+1)
            result2.Seq |> print
            printf "\nResult 3: %A: " (n+2)
            result3.Seq |> print
            ()

    for i in 1 .. 10 do
        for n in 1 .. 1000 do
            testRangeAdd n i
        // Progress: number of test cases run so far.
        printf "\n%d" (i * 1000)

    printf "\nDone"

// Run the randomized test, then block until a line is read from the console.
TestRange ()

ignore (System.Console.ReadLine ())

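Test cases for Range classification. Each row shows how a new range is classified against the existing range (5, 9), which is drawn as the "Test Case" row: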

After         (11, 14)      |   | <-->
AfterOverlap  (10, 14)      |   |<--->
AfterOverlap  ( 9, 14)      |   +---->
AfterOverlap  ( 6, 14)      |<--+---->
 "Test Case"  ( 5,  9)      +---+
BeforeOverlap ( 0,  8) <----+-->|
BeforeOverlap ( 0,  5) <----+   |
BeforeOverlap ( 0,  4) <--->|   |
Before        ( 0,  3) <--> |   |
Superset      ( 4, 10)     <+---+>
Subset        ( 5,  9)      +---+
Subset        ( 6,  8)      |<->|

This is not an answer.

I adapted my test case to run against Juliet's code. It fails on a number of cases, although I do see it passing some tests.

/// A closed interval [min, max] of int64 values.
/// ToString renders it as "(min, max)".
type Range =
    { min : int64
      max : int64 }
    override r.ToString() = sprintf "(%d, %d)" r.min r.max

// Builds one of Juliet's trees by inserting each (min, max) pair in sequence order,
// starting from the empty tree Nil.
let rangeSeqToJTree ranges =
    Seq.fold (fun tree { min = lo; max = hi } -> insert (lo, hi) tree) Nil ranges

// Lazily lists the ranges of one of Juliet's trees in in-order (ascending) order.
let JTreeToRangeSeq node =
    let rec walk tree =
        seq {
            match tree with
            | JNode(l, lo, hi, r) ->
                yield! walk l
                yield { min = lo; max = hi }
                yield! walk r
            | Nil -> ()
        }
    walk node

/// Randomized self-consistency test for Juliet's immutable tree (JNode/Nil).
/// For each seed it inserts the same 15 small random ranges in three shuffled
/// orders.  It prints the inputs and results of any case where the three trees
/// differ, or where a result is not an ascending list of disjoint,
/// non-adjacent ranges.
let TestJTree() =
    printf "\n"

    // 5 pseudo-random ranges for seed n: min in [0, 15), max - min in [0, 4].
    let source(n) = 
        let rnd = new Random(n)
        let rand x = rnd.NextDouble() * float x |> int64
        let rangeRnd() =
            let a = rand 15
            {min = a; max = a + rand 5}
        Array.init 5 (fun _ -> rangeRnd())

    // In-place Fisher-Yates shuffle seeded by n.  The bound is i + 1 because
    // Random.Next's upper bound is exclusive; Next(i) would be Sattolo's
    // algorithm, which only ever produces cyclic permutations.
    let shuffle n (array:_[]) =
        let rnd = new Random(n)
        for i in 0 .. array.Length - 1 do
            let j = rnd.Next(i + 1)
            let temp = array.[i]
            array.[i] <- array.[j]
            array.[j] <- temp
        array

    // Full sequence equality.  Seq.forall2 ignores the tail of the longer
    // sequence, so it cannot detect a result with extra or missing ranges.
    let sameRanges (a : Range seq) (b : Range seq) =
        List.ofSeq a = List.ofSeq b

    // True when every range has min <= max and consecutive ranges ascend with
    // at least one value between them (no overlap, no adjacency).
    let wellFormed (ranges : Range seq) =
        let rs = List.ofSeq ranges
        List.forall (fun r -> r.min <= r.max) rs
        && (rs |> Seq.pairwise |> Seq.forall (fun (a, b) -> a.max + 1L < b.min))

    let testRangeAdd n i =
        let dataSet1 = source (n+0)
        let dataSet2 = source (n+1)
        let dataSet3 = source (n+2)
        let result1 = Array.concat [dataSet1; dataSet2; dataSet3] |> shuffle (i+3) |> rangeSeqToJTree
        let result2 = Array.concat [dataSet2; dataSet3; dataSet1] |> shuffle (i+4) |> rangeSeqToJTree
        let result3 = Array.concat [dataSet3; dataSet1; dataSet2] |> shuffle (i+5) |> rangeSeqToJTree
        let seq1 = result1 |> JTreeToRangeSeq
        let seq2 = result2 |> JTreeToRangeSeq
        let seq3 = result3 |> JTreeToRangeSeq
        // Equality is transitive, so checking 1-2 and 2-3 covers all pairs.
        let consistent = sameRanges seq1 seq2 && sameRanges seq2 seq3
        let valid = wellFormed seq1

        let print dataSet =
            dataSet |> Seq.iter (fun r -> printf "%s " <| string r)

        if not (consistent && valid) then
            printf "\n\nTest# %A: " n
            printf "\nSource 1: %A: " (n+0)
            dataSet1 |> print
            printf "\nSource 2: %A: " (n+1)
            dataSet2 |> print
            printf "\nSource 3: %A: " (n+2)
            dataSet3 |> print
            printf "\n\nResult 1: %A: " (n+0)
            result1 |> JTreeToRangeSeq |> print
            printf "\nResult 2: %A: " (n+1)
            result2 |> JTreeToRangeSeq |> print
            printf "\nResult 3: %A: " (n+2)
            result3 |> JTreeToRangeSeq |> print
            ()

    for i in 1 .. 1 do
        for n in 1 .. 10 do
            testRangeAdd n i
        // Progress: number of test cases run so far.
        printf "\n%d" (i * 10)

    printf "\nDone"

// Run the randomized test against Juliet's tree.
do TestJTree ()

Solution

  • Got it working! I think the hardest part was figuring out how to make recursive calls on children while passing state back up the stack.

    Performance is rather interesting. When inserting mainly ranges that collide and get merged together, the mutable version is faster; when inserting mainly non-overlapping ranges that fill out the tree, the immutable version is faster. I've seen the difference swing by as much as 100% in either direction.

    Here's the complete code.

    module StackOverflowQuestion
    
    open System
    
    /// A closed interval [min, max] of int64 values.
    /// ToString renders it as "(min, max)".
    type Range =
        { min : int64
          max : int64 }
        override r.ToString() = sprintf "(%d, %d)" r.min r.max
    
    // Immutable binary search tree of ranges.  Node(left, min, max, right)
    // stores the range [min, max]; Nil is the empty tree.  The insertion code
    // below aims to keep every range in `left` ending below min - 1 and every
    // range in `right` starting above max + 1, so that stored ranges neither
    // overlap nor touch.
    type RangeTree =
        | Node of RangeTree * int64 * int64 * RangeTree
        | Nil
    
    // Used after a node's min has been lowered to n (mergeLeft passes that
    // node's left child as `tree`): walks right looking for the first range
    // that n overlaps or touches.  Each right child reached by the widened
    // range is replaced by its own left child; its right subtree lies between
    // it and the merging node and is covered too.  When the removed child
    // actually overlaps or touches n, its min becomes the new min and the walk
    // stops.  Returns the rebuilt subtree and the (possibly smaller) min.
    let rec cleanRight n tree =
        match tree with
        | Node(l, lo, hi, (Node(rLeft, rLo, rHi, _) as r)) ->
            if n <= rHi + 1L then
                let pruned = Node(l, lo, hi, rLeft)
                if n >= rLo then pruned, rLo
                else cleanRight n pruned
            else
                // A gap separates r from n; any overlap lies further right.
                let r', n' = cleanRight n r
                Node(l, lo, hi, r'), n'
        | _ -> tree, n
    
    // Mirror image of cleanRight, used after a node's max has been raised to x
    // (mergeRight passes that node's right child as `node`).  It walks left
    // looking for the first range that x overlaps or touches:
    //   - Left children wholly covered by the widened range are replaced by
    //     their right child.
    //   - The first overlapping child is removed and its max returned as the
    //     new max.
    // Returns the rebuilt subtree and the (possibly larger) max.
    let rec cleanLeft x node =
        match node with
        | Node(Node(_, min', max', right') as left, min, max, right) ->
            if x < min' - 1L then
                // Gap before the left child: any overlap lies further left.
                let node, x' = left |> cleanLeft x
                Node(node, min, max, right), x'
            elif x <= max' then
                // Left child overlaps or touches x: absorb its max and drop it.
                Node(right', min, max, right), max'
            else
                // Left child is wholly covered: drop it and keep looking.
                Node(right', min, max, right) |> cleanLeft x
        | _ -> node, x
    
    // Called after a node's min has been lowered to n; the node passed in
    // already carries n as its min.
    //   - Left children wholly covered by the widened range are dropped.
    //   - A left child that overlaps or touches the range is absorbed, and
    //     its min becomes the node's min.
    //   - If the left child lies entirely before n, cleanRight searches that
    //     child's right spine for an overlap instead.
    let rec mergeLeft n node =
        match node with
        | Node(Node(left', min', max', _) as left, min, max, right) ->
            // `n < min'` rather than `n <= min' - 1L`: the latter wraps around
            // when the left child starts at Int64.MinValue, which would make
            // that child look covered and discard it.
            if n < min' then
                Node(left', min, max, right) |> mergeLeft n
            elif n <= max' + 1L then
                Node(left', min', max, right)
            else
                let cleaned, newMin = left |> cleanRight n
                Node(cleaned, newMin, max, right)
        | _ -> node
    
    // Mirror of mergeLeft: called after a node's max has been raised to x (the
    // node passed in already carries x as its max).
    //   - Right children wholly covered by the widened range are dropped.
    //   - One that overlaps or touches the range is absorbed, and its max
    //     becomes the node's max.
    //   - Otherwise cleanLeft searches the right child's left spine.
    let rec mergeRight x node =
        match node with
        | Node(left, min, max, (Node(_, min', max', right') as right)) ->
            // `x > max'` rather than `x >= max' + 1L`: the latter wraps around
            // when the right child ends at Int64.MaxValue, which would make
            // that child look covered and discard it.
            if x > max' then
                Node(left, min, max, right') |> mergeRight x
            elif x >= min' - 1L then
                Node(left, min, max', right')
            else
                let cleaned, newMax = right |> cleanLeft x
                Node(left, min, newMax, cleaned)
        | node -> node
    
    // Classifies a new range (min, max) against an existing range (min', max'):
    //   After / Before               - disjoint with a gap of at least one value
    //   AfterOverlap / BeforeOverlap - overlaps or touches only the max / min end
    //   Subset                       - lies entirely inside the existing range
    //   Superset                     - extends past both ends
    let (|Before|After|BeforeOverlap|AfterOverlap|Superset|Subset|) (min, max, min', max') = 
        // `min - 1L > max'` (guarded by min > max', so it cannot wrap) rather
        // than `min > max' + 1L`, which wraps when max' = Int64.MaxValue and
        // would classify almost any range as After.
        if min > max' && min - 1L > max' then After
        elif min >= min' then
            if max <= max' then Subset
            else AfterOverlap
        // From here on min < min', so `min' - 1L` cannot wrap.
        elif max < min' - 1L then Before
        elif max <= max' then BeforeOverlap
        else Superset
    
    // Returns a new tree with the range [lo, hi] added.  A range that neither
    // overlaps nor touches a node's range is pushed into the matching subtree.
    // One that does widens the node's range, which then absorbs the children
    // it now reaches.  A subset leaves the tree unchanged.
    let rec insert lo hi tree =
        match tree with
        | Nil -> Node(Nil, lo, hi, Nil)
        | Node(left, min, max, right) ->
            match lo, hi, min, max with
            | Subset        -> tree
            | Before        -> Node(insert lo hi left, min, max, right)
            | After         -> Node(left, min, max, insert lo hi right)
            | BeforeOverlap -> mergeLeft lo (Node(left, lo, max, right))
            | AfterOverlap  -> mergeRight hi (Node(left, min, hi, right))
            | Superset      -> Node(left, lo, hi, right) |> mergeLeft lo |> mergeRight hi
    
    // Builds an immutable tree by inserting the ranges in sequence order,
    // starting from the empty tree Nil.
    let rangeSeqToRangeTree ranges =
        Seq.fold (fun tree { min = lo; max = hi } -> insert lo hi tree) Nil ranges
    
    // Lazily lists the tree's ranges in in-order (ascending) order.
    let rangeTreeToRangeSeq node =
        let rec walk tree =
            seq {
                match tree with
                | Nil -> ()
                | Node(l, lo, hi, r) ->
                    yield! walk l
                    yield { min = lo; max = hi }
                    yield! walk r
            }
        walk node
    
    /// Randomized self-consistency test for the immutable RangeTree.  For each
    /// seed it inserts the same 600 random ranges in three shuffled orders.  It
    /// prints the inputs and results of any case where the three trees differ,
    /// or where a result is not an ascending list of disjoint, non-adjacent
    /// ranges.
    let TestImmutableRangeTree() =
        printf "\n"

        // 200 pseudo-random ranges for seed n: min in [0, 15000), max - min in [0, 149].
        let source(n) = 
            let rnd = new Random(n)
            let rand x = rnd.NextDouble() * float x |> int64
            let rangeRnd() =
                let a = rand 15000
                {min = a; max = a + rand 150}
            Array.init 200 (fun _ -> rangeRnd())

        // In-place Fisher-Yates shuffle seeded by n.  The bound is i + 1
        // because Random.Next's upper bound is exclusive; Next(i) would be
        // Sattolo's algorithm, which only ever produces cyclic permutations.
        let shuffle n (array:_[]) =
            let rnd = new Random(n)
            for i in 0 .. array.Length - 1 do
                let j = rnd.Next(i + 1)
                let temp = array.[i]
                array.[i] <- array.[j]
                array.[j] <- temp
            array

        // Full sequence equality.  Seq.forall2 ignores the tail of the longer
        // sequence, so it cannot detect a result with extra or missing ranges.
        let sameRanges (a : Range seq) (b : Range seq) =
            List.ofSeq a = List.ofSeq b

        // True when every range has min <= max and consecutive ranges ascend
        // with at least one value between them (no overlap, no adjacency).
        let wellFormed (ranges : Range seq) =
            let rs = List.ofSeq ranges
            List.forall (fun r -> r.min <= r.max) rs
            && (rs |> Seq.pairwise |> Seq.forall (fun (a, b) -> a.max + 1L < b.min))

        let print dataSet =
            dataSet |> Seq.iter (fun r -> printf "%s " <| string r)

        let testRangeAdd n i =
            let dataSet1 = source (n+0)
            let dataSet2 = source (n+1)
            let dataSet3 = source (n+2)
            let result1 = Array.concat [dataSet1; dataSet2; dataSet3] |> shuffle (i+3) |> rangeSeqToRangeTree
            let result2 = Array.concat [dataSet2; dataSet3; dataSet1] |> shuffle (i+4) |> rangeSeqToRangeTree
            let result3 = Array.concat [dataSet3; dataSet1; dataSet2] |> shuffle (i+5) |> rangeSeqToRangeTree
            let seq1 = result1 |> rangeTreeToRangeSeq
            let seq2 = result2 |> rangeTreeToRangeSeq
            let seq3 = result3 |> rangeTreeToRangeSeq
            // Equality is transitive, so checking 1-2 and 2-3 covers all pairs.
            let consistent = sameRanges seq1 seq2 && sameRanges seq2 seq3
            let valid = wellFormed seq1

            if not (consistent && valid) then
                printf "\n\nTest# %A: " n
                printf "\nSource 1: %A: " (n+0)
                dataSet1 |> print
                printf "\nSource 2: %A: " (n+1)
                dataSet2 |> print
                printf "\nSource 3: %A: " (n+2)
                dataSet3 |> print
                printf "\n\nResult 1: %A: " (n+0)
                result1 |> rangeTreeToRangeSeq |> print
                printf "\nResult 2: %A: " (n+1)
                result2 |> rangeTreeToRangeSeq |> print
                printf "\nResult 3: %A: " (n+2)
                result3 |> rangeTreeToRangeSeq |> print
                ()

        for i in 1 .. 10 do
            for n in 1 .. 100 do
                testRangeAdd n i
            // Progress: number of test cases run so far (100 per outer pass).
            printf "\n%d" (i * 100)

        printf "\nDone"
    
    // Run the randomized test, then block until a line is read from the console.
    TestImmutableRangeTree ()

    ignore (System.Console.ReadLine ())