I get the following error when running a tail-recursive function in Scala on Spark:

org.apache.spark.SparkException: Job aborted due to stage failure: Task serialization failed: java.lang.StackOverflowError

I was under the impression that tail recursion in Scala cannot overflow the stack, and that this was one of Scala's strengths. The method is below:
import org.apache.spark.rdd.RDD
import scala.math.{pow, sqrt}

// `Vector` is assumed to be a linear-algebra vector type that supports
// `dot`, `+`, `-` and scalar `*` (e.g. breeze.linalg.DenseVector).
def gdAll(inputRDD: RDD[(Int, Vector, Int, Vector, Double)]): RDD[(Int, Vector, Int, Vector, Double)] = {
  val step = 0.0000055
  val h4 = 0.05
  // Attach the user-factor error vector to every rating record.
  val errors = inputRDD.map { case (itemid, itemVector, userid, userVector, rating) =>
    (itemid, itemVector, userid, userVector, rating,
      ((rating - userVector.dot(itemVector)) * itemVector) - h4 * userVector)
  }.cache
  val currentRMSE = sqrt(errors.aggregate(0.0)((accum, rating) =>
    accum + pow(rating._5 - rating._4.dot(rating._2), 2), _ + _) / errors.count)
  val totalUserError = errors.aggregate(Vector(0.0, 0.0))((accum, error) => accum + error._6, _ + _)
  val usersByKey = errors.map { case (itemid, itemVector, userid, userVector, rating, error) =>
    (userid, (userVector, itemid, itemVector, rating, error))
  }
  // Gradient step on the user factors.
  val updatedUserFactors = usersByKey.map { case (userid, (userVector, itemid, itemVector, rating, error)) =>
    (itemid, itemVector, userid, userVector + (step * totalUserError), rating)
  }
  // Recompute the error with the updated user factors, now for the item side.
  val fullyUpdatedUserFactors = updatedUserFactors.map { case (itemid, itemVector, userid, userVector, rating) =>
    (itemid, itemVector, userid, userVector, rating,
      ((rating - userVector.dot(itemVector)) * userVector) - h4 * itemVector)
  }
  val itemsByKey = fullyUpdatedUserFactors.map { case (itemid, itemVector, userid, userVector, rating, error) =>
    (itemid, (itemVector, userid, userVector, rating, error))
  }
  val totalItemError = fullyUpdatedUserFactors.aggregate(Vector(0.0, 0.0))((accum, error) => accum + error._6, _ + _)
  // Gradient step on the item factors.
  val updatedItemFactors = itemsByKey.map { case (itemid, (itemVector, userid, userVector, rating, error)) =>
    (itemid, itemVector + (step * totalItemError), userid, userVector, rating) // totalItemError = itemError
  }
  val newRMSE = sqrt(updatedItemFactors.aggregate(0.0)((accum, rating) =>
    accum + pow(rating._5 - rating._4.dot(rating._2), 2), _ + _) / errors.count)
  println("Original RMSE: " + currentRMSE + " New RMSE: " + newRMSE)
  val changeInRMSE = (newRMSE - currentRMSE).abs
  if (changeInRMSE < 0.0000005) {
    return updatedItemFactors // converged
  }
  errors.unpersist()
  gdAll(updatedItemFactors) // repeat if the change is still large
}
Any ideas? Thank you.
This is the subject of a Spark Summit East 2015 talk, Experience and Lessons Learned for Large-Scale Graph Analysis using GraphX.
What happens is that with each iteration the RDD lineage grows. The lineage is serialized recursively, so at some point this causes a StackOverflowError.
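You can see the growth directly: each recursive call stacks new transformations on top of the previous RDD, and toDebugString shows the plan getting deeper. A minimal sketch (the RDD and multiplier here are illustrative stand-ins for your factors and one gdAll step; sc is your SparkContext):

import org.apache.spark.rdd.RDD

// Each pass adds a map stage to the lineage, just as each gdAll call
// stacks new transformations on updatedItemFactors.
var rdd: RDD[Double] = sc.parallelize(1 to 1000).map(_.toDouble)
for (i <- 1 to 100) {
  rdd = rdd.map(_ * 1.000001)
  if (i % 25 == 0)
    println(s"iteration $i: ${rdd.toDebugString.split('\n').length} lineage lines")
}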
Possible workarounds are:

- Increasing the JVM thread stack size (-Xss).
- Checkpointing the RDD (RDD.checkpoint) to truncate the lineage; a sketch follows below. (The talk includes details about why this is not a simple fix.)
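For the first option, you can pass the flag on spark-submit, e.g. --driver-java-options "-Xss16m" (the task serialization in this stack trace happens on the driver); the right size is something you would have to tune. For the second, here is a minimal sketch of periodic checkpointing; the names and the every-10-iterations interval are illustrative, and it assumes sc.setCheckpointDir has been called beforehand:

import org.apache.spark.rdd.RDD

// Sketch: cut the lineage every `interval` iterations. checkpoint() only
// takes effect when the RDD is materialized, hence the count() to force it.
def iterate(rdd: RDD[Double], remaining: Int, interval: Int = 10): RDD[Double] = {
  if (remaining == 0) rdd
  else {
    val next = rdd.map(_ * 1.000001) // stands in for one gdAll step
    if (remaining % interval == 0) {
      next.checkpoint() // lineage is truncated once `next` is computed
      next.count()      // force materialization so the checkpoint happens now
    }
    iterate(next, remaining - 1, interval)
  }
}

In your gdAll you would thread an iteration counter through the recursion the same way and checkpoint updatedItemFactors every few iterations.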