Tags: haskell, tree, haskell-lens, lenses

Traversing and adding elements to a Data.Tree using Lenses in Haskell


I'm starting to use lenses, and so far I haven't managed to use them in a concrete part of a codebase I'm writing. My objective is to update a rose tree structure, such as the one in Data.Tree, by adding a new node inside one of the existing ones. To do so, I thought it would make sense to identify each node with a unique id, so it would look like this:

type MyTree = Tree Id
type Path = [Id]

addToTree :: MyTree -> MyTree -> Path -> MyTree
addToTree originalTree newNode path = undefined

The function addToTree would have to traverse originalTree by following the path of ids, add newNode at that level, and return the whole updated tree. I haven't had problems writing a getter for that, but I haven't been able to find an appropriate lens to perform the update with.
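To pin down the intended semantics, here is a minimal lens-free sketch. It assumes Id = Int (the question leaves Id abstract), keeps the argument order of the signature above, and assumes the path starts with the root's own id; addToTreePlain is a hypothetical name.

```haskell
import Data.Tree (Tree (..))

-- Assumption: the question leaves Id abstract; Int is used here.
type Id = Int
type MyTree = Tree Id
type Path = [Id]

-- Hypothetical lens-free reference version. Assumes the path starts with
-- the root's own id; every child whose id matches the next path element is
-- followed, and a path that matches nothing leaves the tree unchanged.
addToTreePlain :: MyTree -> MyTree -> Path -> MyTree
addToTreePlain tree@(Node x ts) newNode path = case path of
  [p]
    | p == x -> Node x (newNode : ts) -- end of the path: add the new node here
  (p : rest)
    | p == x -> Node x (map (\t -> addToTreePlain t newNode rest) ts) -- descend
  _ -> tree -- empty path or id mismatch: leave this subtree alone
```

For example, addToTreePlain (Node 1 [Node 5 [], Node 6 []]) (Node 2 []) [1, 5] evaluates to Node 1 [Node 5 [Node 2 []], Node 6 []].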

This is what I've got so far:

import           Control.Lens
import           Data.Tree
import           Data.Tree.Lens

addToTree :: MyTree -> Path -> MyTree -> MyTree
addToTree tree path branch = tree & (traversalPath path) . branches %~ (branch:)

traversalPath :: (Foldable t, Applicative f, Contravariant f) => t Id -> (MyTree -> f MyTree) -> MyTree -> f MyTree
traversalPath = foldl (\acc id-> acc . childTraversal id) id

childTraversal :: (Indexable Int p, Applicative f) => Id -> p MyTree (f MyTree) -> MyTree -> f MyTree
childTraversal id = branches . traversed . withId id

withId :: (Choice p, Applicative f) => Id -> Optic' p f MyTree MyTree
withId id = filtered (\x -> rootLabel x == id)

But it fails to compile with:

• No instance for (Contravariant Identity)
    arising from a use of ‘traversalPath’
• In the first argument of ‘(.)’, namely ‘(traversalPath path)’
  In the first argument of ‘(%~)’, namely
    ‘(traversalPath path) . branches’
  In the second argument of ‘(&)’, namely
    ‘(traversalPath path) . branches %~ (branch :)’
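The error comes from the Contravariant f constraint in the signature of traversalPath. An optic whose functor must be both Applicative and Contravariant is a Fold, which can only be read from; %~ (that is, over) runs the optic with f = Identity, and Identity has no Contravariant instance. A minimal sketch of the same approach with that constraint dropped, assuming Id = Int and giving the helpers Traversal' types:

```haskell
import Control.Lens
import Data.Tree (Tree (..))
import Data.Tree.Lens (branches)

-- Assumption: the question leaves Id abstract; Int is used here.
type Id = Int
type MyTree = Tree Id
type Path = [Id]

addToTree :: MyTree -> Path -> MyTree -> MyTree
addToTree tree path branch = tree & traversalPath path . branches %~ (branch :)

-- Same as the question's version, minus the Contravariant f constraint,
-- so the composed optic is a traversal and can be used for setting.
-- The lambda's parameter is renamed to i so it no longer shadows id.
traversalPath :: (Foldable t, Applicative f) => t Id -> (MyTree -> f MyTree) -> MyTree -> f MyTree
traversalPath = foldl (\acc i -> acc . childTraversal i) id

-- Focus on the direct children whose label is i.
childTraversal :: Id -> Traversal' MyTree MyTree
childTraversal i = branches . traversed . withId i

withId :: Id -> Traversal' MyTree MyTree
withId i = filtered (\x -> rootLabel x == i)
```

With this version the path starts below the root, so addToTree t [] b adds b directly under the root of t, whereas the answer below expects the path to begin with the root's own id.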

Thanks!


Solution

  • This is not especially elegant, but it should do the trick:

    import Control.Lens
    import Data.Monoid (Endo(..)) -- A tidier idiom for 'foldr (.) id'.
    import Data.List.NonEmpty (NonEmpty(..)) -- You don't want an empty path.
    import qualified Data.List.NonEmpty as N
    import Data.Tree
    import Data.Tree.Lens -- That's where I got 'branches' from.
    
    addToTree :: Eq a => NonEmpty a -> Tree a -> Tree a -> Tree a
    addToTree path newNode oldTree = head $ over pathForests (newNode :) [oldTree]
        where
        pathForests = appEndo $ foldMap (Endo . goDown) path 
        goDown x = traverse . filtered ((x ==) . rootLabel) . branches
    

    (In particular, I never like using head, even in cases like this one in which it can't possibly fail. Feel free to replace it with your favourite circumlocution.)

    Demo:

    GHCi> addToTree (1 :| []) (Node 2 []) (Node 1 [])
    Node {rootLabel = 1, subForest = [Node {rootLabel = 2, subForest = []}]}
    GHCi> addToTree (4 :| []) (Node 2 []) (Node 1 [])
    Node {rootLabel = 1, subForest = []}
    GHCi> addToTree (1 :| [5]) (Node 2 []) (Node 1 [Node 5 [], Node 6 []])
    Node {rootLabel = 1, subForest = [Node {rootLabel = 5, subForest = [Node {rootLabel = 2, subForest = []}]},Node {rootLabel = 6, subForest = []}]}
    GHCi> addToTree (1 :| [7]) (Node 2 []) (Node 1 [Node 5 [], Node 6 []])
    Node {rootLabel = 1, subForest = [Node {rootLabel = 5, subForest = []},Node {rootLabel = 6, subForest = []}]}
    GHCi> addToTree (1 :| [5,3]) (Node 2 []) (Node 1 [Node 5 [], Node 6 []])
    Node {rootLabel = 1, subForest = [Node {rootLabel = 5, subForest = []},Node {rootLabel = 6, subForest = []}]}
    

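    Do note that we are dealing with traversals, not lenses: there is no guarantee or expectation that the target of the path exists or is unique. If it doesn't exist, the tree comes back unchanged (as in the second and the last two demo cases); if several siblings share the same id, each of them gets the new node.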

    Here is a more stylised variant, without head and using alaf to handle the Endo wrapping.

    addToTree :: Eq a => NonEmpty a -> Tree a -> Tree a -> Tree a
    addToTree (desiredRoot :| path) newNode oldTree@(Node x ts)
        | x == desiredRoot = Node x (over pathForests (newNode :) ts)
        | otherwise = oldTree
        where
        pathForests = alaf Endo foldMap goDown path 
        goDown x = traverse . filtered ((x ==) . rootLabel) . branches
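Both versions rely on the same trick: foldMap with Endo (and alaf Endo foldMap, which does the Endo wrapping and unwrapping for you) composes a list of functions exactly like foldr (.) id, with the first element of the list outermost. A small sketch with plain Int -> Int functions standing in for the goDown optics (steps and the via* names are hypothetical):

```haskell
import Control.Lens (alaf)
import Data.Monoid (Endo (..))

-- Hypothetical stand-ins for the per-id goDown optics.
steps :: [Int -> Int]
steps = [(+ 1), (* 2), subtract 3]

-- Three equivalent ways to build (+ 1) . (* 2) . subtract 3 from the list.
viaFoldr, viaEndo, viaAlaf :: Int -> Int
viaFoldr = foldr (.) id steps           -- explicit composition
viaEndo = appEndo (foldMap Endo steps)  -- first version's idiom
viaAlaf = alaf Endo foldMap id steps    -- second version's idiom
```

For input 10 all three apply subtract 3 first, then (* 2), then (+ 1), giving 15. With optics, the leftmost optic in a composition addresses the outermost level of the structure, which is why the first id of the path selects the topmost node.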