
Apache Spark in Scala: StackOverflowError with tail recursion


The following error occurs:

org.apache.spark.SparkException: Job aborted due to stage failure: Task serialization failed: java.lang.StackOverflowError

It happens when I run a tail-recursive function in Scala. I was under the impression that tail recursion in Scala can't overflow the stack, and that this was one of Scala's strengths.

Here is the method:

```scala
import org.apache.spark.rdd.RDD
import scala.math.{pow, sqrt}

// `Vector` is not defined in the question. The operations used here (dot, +, -,
// Double * Vector, Vector(0.0, 0.0)) appear to match the deprecated
// org.apache.spark.util.Vector; any vector type with those operations would do.
def gdAll(inputRDD: RDD[(Int, Vector, Int, Vector, Double)]): RDD[(Int, Vector, Int, Vector, Double)] = {

  val step = 0.0000055 // learning rate
  val h4 = 0.05        // regularization weight

  // Append the user-side gradient term to every (item, user, rating) tuple
  val errors = inputRDD.map { case (itemid, itemVector, userid, userVector, rating) =>
    (itemid, itemVector, userid, userVector, rating,
      ((rating - userVector.dot(itemVector)) * itemVector) - h4 * userVector)
  }.cache()

  val currentRMSE = sqrt(errors.aggregate(0.0)(
    (accum, rating) => accum + pow(rating._5 - rating._4.dot(rating._2), 2), _ + _) / errors.count)

  // Sum of the user-side gradient terms over all ratings
  val totalUserError = errors.aggregate(Vector(0.0, 0.0))((accum, error) => accum + error._6, _ + _)

  val usersByKey = errors.map { case (itemid, itemVector, userid, userVector, rating, error) =>
    (userid, (userVector, itemid, itemVector, rating, error))
  }

  // Gradient step on every user vector
  val updatedUserFactors = usersByKey.map { case (userid, (userVector, itemid, itemVector, rating, error)) =>
    (itemid, itemVector, userid, userVector + (step * totalUserError), rating)
  }

  // Append the item-side gradient term, computed with the updated user vectors
  val fullyUpdatedUserFactors = updatedUserFactors.map { case (itemid, itemVector, userid, userVector, rating) =>
    (itemid, itemVector, userid, userVector, rating,
      ((rating - userVector.dot(itemVector)) * userVector) - h4 * itemVector)
  }

  val itemsByKey = fullyUpdatedUserFactors.map { case (itemid, itemVector, userid, userVector, rating, error) =>
    (itemid, (itemVector, userid, userVector, rating, error))
  }

  // Sum of the item-side gradient terms over all ratings
  val totalItemError = fullyUpdatedUserFactors.aggregate(Vector(0.0, 0.0))((accum, error) => accum + error._6, _ + _)

  // Gradient step on every item vector
  val updatedItemFactors = itemsByKey.map { case (itemid, (itemVector, userid, userVector, rating, error)) =>
    (itemid, itemVector + (step * totalItemError), userid, userVector, rating) // totalItemError = itemError
  }

  val newRMSE = sqrt(updatedItemFactors.aggregate(0.0)(
    (accum, rating) => accum + pow(rating._5 - rating._4.dot(rating._2), 2), _ + _) / errors.count)

  println("Original RMSE: " + currentRMSE + " New RMSE: " + newRMSE)

  val changeInRMSE = (newRMSE - currentRMSE).abs

  if (changeInRMSE < 0.0000005) {
    updatedItemFactors        // converged
  } else {
    errors.unpersist()
    gdAll(updatedItemFactors) // repeat if change is still large
  }
}
```

Any ideas? Thank you.


Solution

  • This is the subject of a Spark Summit East 2015 talk, "Experience and Lessons Learned for Large-Scale Graph Analysis using GraphX".

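    What happens is that the RDD lineage grows with each iteration: every call to gdAll adds several map transformations on top of the previous RDD, and nothing ever truncates that chain. When Spark ships a task, it serializes the RDD the task runs on, and that object graph includes the parent RDDs it was derived from. The serialization is recursive, so once the lineage is deep enough it throws a StackOverflowError. That is why the message says "Task serialization failed": the overflow comes from Spark's task serialization, not from the Scala recursion itself.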
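To see why the recursion itself is not the problem, here is a Spark-free toy model. The names LineageToy, Step, build and serializes are made up for illustration: each Step keeps a reference to the step it was derived from, roughly the way each mapped RDD references its parent. Building a very long chain with a @tailrec function is fine, because the compiler turns it into a loop. Java serialization, however, walks the chain recursively and overflows once the chain is deep enough.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.annotation.tailrec

object LineageToy {
  // Hypothetical stand-in for an RDD: each step references its parent step.
  final case class Step(id: Int, parent: Option[Step])

  // Builds a chain `depth` steps long. The recursive call is in tail position,
  // so @tailrec compiles it to a loop and it uses constant stack space.
  @tailrec
  def build(depth: Int, acc: Step = Step(0, None)): Step =
    if (acc.id >= depth) acc
    else build(depth, Step(acc.id + 1, Some(acc)))

  // Java serialization recurses into every referenced object, so the stack
  // space it needs grows with the length of the chain.
  def serializes(s: Step): Boolean =
    try {
      val out = new ObjectOutputStream(new ByteArrayOutputStream())
      out.writeObject(s)
      out.close()
      true
    } catch {
      case _: StackOverflowError => false
    }

  def main(args: Array[String]): Unit = {
    println(serializes(build(100)))     // short chain: serializes fine
    println(serializes(build(1000000))) // long chain: overflows while serializing
  }
}
```

The depth at which serialization fails depends on the thread stack size (-Xss); 1,000,000 is simply chosen to be far beyond the default.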

    Possible workarounds are:

    • Stop the iteration before this happens.
    • Allocate a larger thread stack with -Xss, at least for the driver JVM, where tasks are serialized.
    • Checkpoint the RDD periodically with RDD.checkpoint, which truncates its lineage. (The talk includes details about why this is not a simple fix.)
    • Write the RDD out to disk and read it back, which also gives it a fresh lineage.
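A minimal sketch of the last two workarounds, assuming a checkpoint/output directory you can write to (local filesystem or HDFS). The object LineageTruncation, its method names, the `interval` parameter and the halving update (a hypothetical stand-in for one round of gdAll) are all made up for illustration; the Spark calls used are setCheckpointDir, cache, checkpoint, count, saveAsObjectFile and objectFile.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD

object LineageTruncation {

  // Runs `iterations` rounds of a toy update (halving every value) and
  // checkpoints every `interval` rounds, so the lineage never gets deeper
  // than roughly `interval` map steps.
  def iterateWithCheckpoint(sc: SparkContext, init: RDD[Double], iterations: Int,
                            interval: Int, checkpointDir: String): RDD[Double] = {
    sc.setCheckpointDir(checkpointDir)
    var current = init
    for (i <- 1 to iterations) {
      current = current.map(_ * 0.5) // hypothetical stand-in for one gdAll round
      if (i % interval == 0) {
        current.cache()      // avoids recomputing the RDD when the checkpoint is written
        current.checkpoint() // must be called before any job runs on this RDD
        current.count()      // an action triggers the checkpoint, which cuts the lineage
      }
    }
    current
  }

  // Alternative: write the RDD out and read it back. The reloaded RDD's
  // lineage starts at the saved files instead of the original chain of maps.
  def writeAndReload(sc: SparkContext, rdd: RDD[Double], path: String): RDD[Double] = {
    rdd.saveAsObjectFile(path)
    sc.objectFile[Double](path)
  }
}
```

Applied to gdAll, the same idea would mean passing an iteration counter along and checkpointing updatedItemFactors every N calls; N trades checkpoint I/O against lineage depth.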