(Another) Stack overflow on loop-recur in Clojure


Similar questions: One, Two, Three.

I am thoroughly flummoxed here. I'm using the loop-recur form, I'm using doall, and still I get a stack overflow for large loops. My Clojure version is 1.5.1.

Context: I'm training a neural net to mimic XOR. The function xor is the feed-forward function, taking weights and input and returning the result; the function b-xor is the back-propagation function that returns updated weights given the results of the last call to xor.

The following loop runs fine, runs very fast, and returns a result; judging by that result, it is training the weights perfectly:

(loop [res 1        ; <- initial value doesn't matter
       weights xorw ; <- initial pseudo-random weights
       k 0]         ; <- count
  (if (= k 1000000)
      res
      (let [n (rand-int 4)
            r (doall (xor weights (first (nth xorset n))))]
        (recur (doall r)
               (doall (b-xor weights r (second (nth xorset n))))
               (inc k)))))

But of course, that only gives me the result of the very last run. Obviously I want to know what weights have been trained to get that result! The following loop, with nothing but the return value changed, overflows:

(loop [res 1
       weights xorw
       k 0]
  (if (= k 1000000)
      weights              ; <- new return value
      (let [n (rand-int 4)
            r (doall (xor weights (first (nth xorset n))))]
        (recur (doall r)
               (doall (b-xor weights r (second (nth xorset n))))
               (inc k)))))

This doesn't make sense to me. The entirety of weights gets used in each call to xor. So why could I use weights internally but not print it to the REPL?

And as you can see, I've stuck doall in all manner of places, more than I think I should need. XOR is a toy example, so weights and xorset are both very small. I believe the overflow occurs not from the execution of xor and b-xor, but when the REPL tries to print weights, for these reasons:

(1) this loop can go up to 1500 without overflowing the stack.

(2) the time the loop runs is consistent with the length of the loop; that is, if I loop to 5000, it runs for half a second and then prints a stack overflow; if I loop to 1000000, it runs for ten seconds and then prints a stack overflow -- again, only if I print weights and not res at the end.

(3) EDIT: Also, if I just wrap the loop in (def w ... ), then there is no stack overflow. Attempting to peek at the resulting variable does, though.
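Point (3) is what one would expect if the overflow happens when a lazy sequence is realized for printing rather than inside the loop: def stores the value without printing it. A minimal sketch of that behaviour, assuming nothing about xor or b-xor (the names touched and w below are hypothetical, chosen only for illustration):

```clojure
;; Hypothetical names; this only illustrates that def does not realize a lazy seq.
(def touched (atom false))

;; map returns a lazy seq; the function passed to it has not run yet.
(def w (map (fn [x] (reset! touched true) x) [1 2 3]))

;; Snapshot taken right after def: nothing has walked w so far.
(def touched-after-def @touched)

;; Printing has to walk the seq, which is what finally realizes it.
(def printed (pr-str w))

;; Snapshot taken after printing: the mapped function has now run.
(def touched-after-print @touched)
```

Here touched-after-def is false and touched-after-print is true, so any cost of realizing the sequence, including deep recursion, is paid at print time and not at def time.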

user=> (clojure.stacktrace/e)
java.lang.StackOverflowError: null
 at clojure.core$seq.invoke (core.clj:133)
    clojure.core$map$fn__4211.invoke (core.clj:2490)
    clojure.lang.LazySeq.sval (LazySeq.java:42)
    clojure.lang.LazySeq.seq (LazySeq.java:60)
    clojure.lang.RT.seq (RT.java:484)
    clojure.core$seq.invoke (core.clj:133)
    clojure.core$map$fn__4211.invoke (core.clj:2490)
    clojure.lang.LazySeq.sval (LazySeq.java:42)
nil

Where is the lazy sequence?

If you have suggestions for better ways to do this (this is just my on-the-fly REPL code), that'd be great, but I'm really looking for an explanation as to what is happening in this case.


EDIT 2: Definitely (?) a problem with the REPL.

This is bizarre. weights is a list containing six lists, four of which are empty. So far, so good. Yet trying to print one of these empty lists to the screen results in a stack overflow, though only the first time. The second time it prints without throwing any errors. Printing the non-empty lists produces no stack overflow. Now I can move on with my project, but...what on earth is going on here? Any ideas? (Please pardon the following ugliness, but I thought it might be helpful.)

user=> (def ww (loop etc. etc. ))
#'user/ww
user=> (def x (first ww))
#'user/x
user=> x
StackOverflowError   clojure.lang.RT.seq (RT.java:484)
user=> x
()
user=> (def x (nth ww 3))
#'user/x
user=> x
(8.47089879874061 -8.742792338501289 -4.661609290853221)
user=> (def ww (loop etc. etc. ))
#'user/ww
user=> ww
StackOverflowError   clojure.core/seq (core.clj:133)
user=> ww
StackOverflowError   clojure.core/seq (core.clj:133)
user=> ww
StackOverflowError   clojure.core/seq (core.clj:133)
user=> ww
StackOverflowError   clojure.core/seq (core.clj:133)
user=> ww
(() () () (8.471553034351501 -8.741870954507117 -4.661171802683782) () (-8.861958958234174 8.828933147027938 18.43649480263751 -4.532462509591159))

Solution

  • If you call doall on a sequence that contains more lazy sequences, doall does not recursively iterate through the subsequences. In this particular case, the return value of b-xor contained empty lists that were defined lazily from previous empty lists defined lazily from previous empty lists, and so on. All I had to do was add a single doall to the map that produced the empty lists (in b-xor), and the problem disappeared. With that change to b-xor, this loop (with all of the doalls removed) never overflows:

    (loop [res 1
           weights xorw
           k 0]
      (if (= k 1000000)
          weights 
          (let [n (rand-int 4)
                r (xor weights (first (nth xorset n)))]
            (recur r
                   (b-xor weights r (second (nth xorset n)))
                   (inc k)))))
    

    I hope this is helpful to some other poor soul who thought he'd solved his lazy sequencing issues with a badly-placed doall.
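    The real b-xor is not shown here, so the following is only a sketch of the mechanism under stated assumptions: step-shallow and step-deep are hypothetical stand-ins for b-xor, inc stands in for the actual weight update, and train and try-print are helper names made up for the illustration. It shows that doall realizes only the outer sequence, and that forcing the inner map on every step keeps the nesting from building up.

```clojure
;; Hypothetical stand-ins for b-xor; inc replaces the real weight update.

;; Outer doall only: each row becomes a new, unrealized (map inc row).
(defn step-shallow [weights]
  (doall (map (fn [row] (map inc row)) weights)))

;; doall on the inner map as well: each row is realized before the next step.
(defn step-deep [weights]
  (doall (map (fn [row] (doall (map inc row))) weights)))

;; Same shape as the loop above: apply step n times and return the weights.
(defn train [step n]
  (loop [weights '(() (1.0 2.0))
         k 0]
    (if (= k n)
      weights
      (recur (step weights) (inc k)))))

;; Printing forces realization; report an overflow instead of crashing.
(defn try-print [x]
  (try (pr-str x)
       (catch StackOverflowError _ :overflow)))

;; doall is shallow: the outer seq is realized, the row inside it is not.
(def shallow-row-realized? (realized? (first (step-shallow '((1.0 2.0))))))
(def deep-row-realized?    (realized? (first (step-deep '((1.0 2.0))))))

;; 100000 lazy maps nested around the empty row; realizing them is expected
;; to overflow on a default-sized JVM stack (the exact limit depends on -Xss).
(def shallow-result (try-print (first (train step-shallow 100000))))

;; With the inner doall, each row is at most one level deep, so this prints.
(def deep-result (try-print (train step-deep 100000)))
```

    In this sketch shallow-row-realized? is false and deep-row-realized? is true. Using mapv (available since Clojure 1.4) for the inner map would be another way to avoid the nesting, at the cost of returning vectors instead of lists.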