
How to generate a list of all subtrees in Clojure using higher-order functions?

Given a tree, how do you generate a list of all (proper) subtrees in Clojure using higher-order functions?

Background

I am working on Advent of Code 2019 Problem #6. The problem begins with an adjacency list. I have represented the adjacency list as an n-ary tree, using Clojure lists, with the following structure.

A node that is not a leaf is a list with two parts: the first part is a keyword representing the root of that section of the tree; the second part is n elements representing the branches from that root. Leaves are lists having a keyword as their only element. Thus, I represent a tree of the form,

  B -- C
 /
A
 \
  D

with the following list:

(:A (:B (:C)) (:D))
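For reference, here is a minimal sketch of how an adjacency list could be turned into this nested-list format. It assumes the adjacency list has already been parsed into a map from each parent keyword to a vector of its child keywords; the names `example-adjacency` and `adjacency->tree` are hypothetical and not part of the original post.

```clojure
;; Hypothetical parsed input: each parent keyword maps to its children.
(def example-adjacency {:A [:B :D], :B [:C]})

(defn adjacency->tree
  "Builds the nested-list tree rooted at `root` from a parent->children map.
  A node with no entry in the map becomes a leaf such as (:C).
  (Hypothetical helper, not from the original post.)"
  [adjacency root]
  ;; (apply list root children) puts the root keyword first, then one
  ;; nested list per child, built recursively.
  (apply list root (map #(adjacency->tree adjacency %) (get adjacency root))))

(adjacency->tree example-adjacency :A)
;; => (:A (:B (:C)) (:D))
```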

Solution using Recursion

I want to list every proper subtree of a given tree. I know how to do this using recursion (via loop/recur), as follows:

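(defn subtrees
  [tree]
  ;; Start with the whole tree as the only pending tree.
  (loop [trees (list tree)
         results '()]
    (if (empty? trees)
      results
      (let [;; The branches of a node: everything after its root keyword.
            subtree #(if (keyword? (first %)) (rest %) nil)
            ;; A leaf is a list holding a single keyword, e.g. (:C).
            leaf? #(and (list? %) (keyword? (first %)) (= (count %) 1))
            sub (subtree (first trees))]
        (if (every? leaf? sub)
          ;; Leaves have no branches of their own: only record them.
          (recur (rest trees) (into results sub))
          ;; Otherwise record the branches and queue them for expansion.
          (recur (into (rest trees) sub) (into results sub)))))))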

So I do the work with two accumulators, trees and results. I begin with the whole tree in trees. At each step I take the branches of the first pending tree and add them to results; unless those branches are all leaves, I also add them to trees so that they get expanded later. This gives me a list of all proper subtrees of tree, which is the point of the function. I have a version of this working solution with very detailed comments and a bunch of test cases.
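The same worklist idea can be sketched a little more compactly. Because a leaf such as (:C) has no branches, `(rest leaf)` is empty, so a separate case for leaves is not strictly needed. This sketch (the name `subtrees-worklist` is hypothetical) keeps the pending trees on a vector used as a stack, so it visits the tree depth-first and may return the subtrees in a different order than a list-based worklist would.

```clojure
(defn subtrees-worklist
  "Returns a vector of all proper subtrees of `tree`, where a tree is a list
  whose first element is a keyword and whose remaining elements are subtrees.
  (Hypothetical helper, not from the original post.)"
  [tree]
  (loop [stack   [tree] ; trees whose branches have not been collected yet
         results []]    ; proper subtrees found so far
    (if (empty? stack)
      results
      (let [branches (rest (peek stack))] ; empty for a leaf such as (:C)
        (recur (into (pop stack) branches) ; push the branches for expansion
               (into results branches))))))

(subtrees-worklist '(:A (:B (:C)) (:D)))
;; => [(:B (:C)) (:D) (:C)]
```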

My Question

I should like to know how to accomplish the same using higher-order functions. What I would really like to do is use map and call the function recursively: at each stage, just call subtree on every element in the list. The problem I have encountered is that when I do this, I end up with a huge mess of parentheses and can't consistently drill down through the mess to get to the subtrees. Something like this:

(defn subt
  [trees]
  (let [subtree #(if (keyword? (first %)) (rest %) nil)
        leaf? #(and (list? %) (keyword? (first %)) (= (count %) 1))
        sub (subtree trees)]
    (if (every? leaf? sub)
      nil
      (cons (map subt sub) trees))))

You can see that (map subt sub) is what I'm going for here, but I am running into a lot of difficulty using map, even though my sense is that it is the higher-order function I want. I thought about using reduce as a stand-in for the loop in subtrees above; but because trees grows as subtrees are added to it, I don't think reduce is appropriate, at least with the loop as I have constructed it. I should also say that I'm not interested in a library to do the work; I want to know how to solve it using core functions. Thanks in advance.
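The nesting described above comes from map returning one sequence per branch, so every level of recursion wraps its results in another list. One way to keep the map-and-recurse shape is mapcat, which concatenates the per-branch results into a single flat sequence. A minimal sketch, assuming the list representation from the question; the name `subtrees-mapcat` is hypothetical:

```clojure
(defn subtrees-mapcat
  "Lazily returns all proper subtrees of `tree`, depth-first.
  (Hypothetical helper, not from the original post.)"
  [tree]
  ;; For each branch, emit the branch itself followed by its own subtrees;
  ;; mapcat concatenates these per-branch sequences instead of nesting them.
  (mapcat (fn [branch] (cons branch (subtrees-mapcat branch)))
          (rest tree)))

(subtrees-mapcat '(:A (:B (:C)) (:D)))
;; => ((:B (:C)) (:C) (:D))
```

Since each call recurses into its branches, stack usage grows with the depth of the tree, which is worth keeping in mind for very deep inputs.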


Solution

  • Here is an attempt at computing all the subtrees using various functions from the standard library.

    (defn expand-subtrees [tree-set]
      (into #{} (comp (map rest) cat) tree-set))
    
    (defn all-subtrees [tree]
      (reduce into #{}
              (take-while seq (iterate expand-subtrees #{tree}))))
    

    and we can call it like this:

    (all-subtrees '(:A (:B (:C)) (:D)))
    ;; => #{(:D) (:B (:C)) (:C) (:A (:B (:C)) (:D))}
    

    The helper function expand-subtrees takes a set of trees and returns the set of their immediate subtrees (for each tree, everything after its root keyword). Then we use iterate with expand-subtrees, starting from a set containing only the initial tree, to produce a lazy sequence of sets, one per level of the tree. We take elements from this sequence until a level is empty, that is, until there are no more subtrees. Then we merge all the levels into a single set, which is the result. Since that set also contains the initial tree, you can disj it to keep only the proper subtrees.
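Another option using only core functions is clojure.core/tree-seq, which walks a tree given a predicate for nodes that may have children and a function returning those children. With this list representation every node is a seq and its children are its rest; dropping the first element of the walk (the tree itself) leaves only the proper subtrees. A sketch, with the hypothetical name `subtrees-tree-seq`:

```clojure
(defn subtrees-tree-seq
  "Returns all proper subtrees of `tree` in depth-first order.
  (Hypothetical helper, not from the original post.)"
  [tree]
  ;; tree-seq yields the root first, so `rest` drops the tree itself.
  ;; Children are (rest node), so root keywords are never visited as nodes.
  (rest (tree-seq seq? rest tree)))

(subtrees-tree-seq '(:A (:B (:C)) (:D)))
;; => ((:B (:C)) (:C) (:D))
```

Unlike the set-based version, this returns a sequence, so it keeps duplicate subtrees (for example, two identical leaves under different parents) and has a deterministic depth-first order. If every node name is unique, as in an orbit map, both approaches yield the same subtrees.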