I've recently been trying to use Spark's LBFGS method in my project, but when I read the source code I ran into a big problem with the code I don't understand (I posted it as a picture). Here is the source code link: https://github.com/apache/spark/blob/v1.6.0/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
My problem is: if my input data (label, features) contains only labels and feature vectors, how is the treeAggregate seqOp able to match { case ((grad, loss), (label, features)) }? I thought it could only match { case (label, features) }.
In fact, I don't really understand treeAggregate either. Can someone help me?
I think you haven't really understood the treeAggregate operation.
Let me give you a thorough explanation of the problem in the picture you posted.
After that you will understand why the source code can match these things correctly!
If you find treeAggregate confusing, you can first study aggregate, a simpler but similar operation.
The prototype of aggregate is:
def aggregate[U](zeroValue: U)(seqOp: (U, T) ⇒ U, combOp: (U, U) ⇒ U)(implicit arg0: ClassTag[U]): U
It seems complex, right? Let me clarify it for you:
An RDD is an abstraction over data that is physically distributed across many partitions, so how can we aggregate all of its values into one result?
There are two situations:
merging the values within the same partition;
merging the results across different partitions.
seqOp: (U, T) ⇒ U
This is the operation that merges the values within one partition into a partial result of type U.
combOp: (U, U) ⇒ U
This is the operation that merges the partial results across partitions!
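To make the two levels concrete, here is a minimal plain-Scala model of what aggregate does; it is a sketch, not Spark's actual implementation. It assumes an RDD can be pictured as a Seq[Seq[T]] with one inner Seq per partition, and the object name AggregateModel is hypothetical. seqOp folds the elements of each partition into a U starting from zeroValue, and combOp then merges the per-partition results.

```scala
// Hypothetical plain-Scala model of RDD.aggregate (no Spark needed).
// Assumption: an RDD is modeled as Seq[Seq[T]], one inner Seq per partition.
object AggregateModel {
  def aggregate[T, U](partitions: Seq[Seq[T]])(zeroValue: U)(
      seqOp: (U, T) => U,
      combOp: (U, U) => U): U = {
    // Level 1: inside each partition, fold the elements into a U with seqOp.
    val perPartition: Seq[U] = partitions.map(_.foldLeft(zeroValue)(seqOp))
    // Level 2: merge the per-partition results with combOp.
    perPartition.foldLeft(zeroValue)(combOp)
  }

  def main(args: Array[String]): Unit = {
    // Three "partitions" of Ints; U is (sum, count), a different type from T = Int.
    val data = Seq(Seq(1, 2), Seq(3), Seq(4, 5, 6))
    val (sum, count) = aggregate(data)((0, 0))(
      (acc, x) => (acc._1 + x, acc._2 + 1), // seqOp: add one element to a partial result
      (a, b) => (a._1 + b._1, a._2 + b._2)) // combOp: add two partial results
    println(s"sum=$sum count=$count avg=${sum.toDouble / count}")
  }
}
```

Here U is (Int, Int) while T is Int, so the result type differs from the element type. Running main prints sum=21 count=6 avg=3.5.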
I guess you are familiar with the reduce operation.
In fact, the aggregate operation is more general than the reduce operation.
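The relationship with reduce can be checked with a small plain-Scala model. This is a sketch under the assumption that partitions are modeled as Seq[Seq[T]]; the object name ReduceVsAggregate is hypothetical. When U is the same type as T, zeroValue is a neutral element such as 0, and seqOp and combOp are the same associative function, aggregate gives the same answer as reduce. Switching U to Set[Int] gives a result that reduce cannot return, because reduce must produce the element type.

```scala
// Hypothetical plain-Scala model; partitions are modeled as Seq[Seq[T]].
object ReduceVsAggregate {
  // Fold each partition with seqOp, then merge the partial results with combOp.
  def aggregate[T, U](partitions: Seq[Seq[T]])(zeroValue: U)(
      seqOp: (U, T) => U, combOp: (U, U) => U): U =
    partitions.map(_.foldLeft(zeroValue)(seqOp)).foldLeft(zeroValue)(combOp)

  def main(args: Array[String]): Unit = {
    val data = Seq(Seq(3, 1), Seq(7), Seq(5))
    // reduce: the result type must be the element type (Int).
    val viaReduce = data.flatten.reduce(_ + _)
    // aggregate with U = T, a neutral zero and seqOp == combOp behaves like reduce.
    val viaAggregate = aggregate(data)(0)(_ + _, _ + _)
    // aggregate can also return another type, here the set of distinct values.
    val distinct = aggregate(data)(Set.empty[Int])(_ + _, _ ++ _)
    println(s"$viaReduce $viaAggregate ${distinct.size}")
  }
}
```

Running main prints 16 16 4.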
The reason aggregate exists is that sometimes we need to "reduce" the values for a unique key, but want the result in a different type from the values in the parent RDD.
For example, what if we want to find, for one specific key in the parent RDD, the unique values associated with it?
The value type of this "reduce" operation is obviously different from that of the parent RDD.
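import scala.collection.mutable.HashSet
val pairs = sc.parallelize(Array(("a", 3), ("a", 1), ("b", 7), ("a", 5)))
// seqOp adds one value to the key's set; combOp merges two sets coming from different partitions
val sets = pairs.aggregateByKey(new HashSet[Int])(_ += _, _ ++= _)
sets.collect
res0: Array[(String, scala.collection.mutable.HashSet[Int])] = Array((b,Set(7)), (a,Set(1, 5, 3)))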
The example uses aggregateByKey, but the same idea applies to aggregate: it aggregates the whole dataset instead of the values of each key.
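Applied to the whole dataset, this is what the LBFGS code does. Spark's RDD.treeAggregate takes the same zeroValue, seqOp: (U, T) => U and combOp: (U, U) => U as aggregate, plus a depth parameter (default 2); it only changes how the per-partition results are merged (in a multi-level tree instead of all at once on the driver). In your question T is (label, features) and U is (gradient, loss), so seqOp receives two arguments, an accumulator c and one element v. Matching on the pair (c, v) gives a pattern with two nested tuples, ((grad, loss), (label, features)). Scala treats a pattern-matching anonymous function { case ... } used as a two-argument function the same way, matching on the tuple of its arguments. The sketch below is plain Scala, not Spark: vectors are Array[Double], partitions are Seq[Seq[T]], and localGrad is a hypothetical least-squares gradient standing in for Spark's Gradient.compute.

```scala
// Hypothetical plain-Scala sketch of the seqOp/combOp shape used with treeAggregate in LBFGS.
object LbfgsStyleAggregate {
  type Vec = Array[Double]

  // Hypothetical stand-in for Gradient.compute: adds the least-squares gradient
  // (w·x - y) * x into `grad` in place and returns the loss 0.5 * (w·x - y)^2.
  def localGrad(features: Vec, label: Double, weights: Vec, grad: Vec): Double = {
    val err = features.indices.map(i => features(i) * weights(i)).sum - label
    for (i <- grad.indices) grad(i) += err * features(i)
    0.5 * err * err
  }

  // Model of aggregate; zeroValue is by-name so every partition and the final
  // merge start from a fresh zero, since the gradient array is mutated in place.
  def aggregate[T, U](partitions: Seq[Seq[T]])(zeroValue: => U)(
      seqOp: (U, T) => U, combOp: (U, U) => U): U =
    partitions.map(_.foldLeft(zeroValue)(seqOp)).foldLeft(zeroValue)(combOp)

  def gradientAndLoss(data: Seq[Seq[(Double, Vec)]], weights: Vec): (Vec, Double) =
    aggregate(data)((new Array[Double](weights.length), 0.0))(
      // seqOp: c is the accumulator U = (grad, loss), v is one element T = (label, features).
      // Matching on the pair (c, v) is why the pattern has two nested tuples.
      (c, v) => (c, v) match {
        case ((grad, loss), (label, features)) =>
          val l = localGrad(features, label, weights, grad)
          (grad, loss + l)
      },
      // combOp: both arguments are accumulators of type U.
      (c1, c2) => (c1, c2) match {
        case ((grad1, loss1), (grad2, loss2)) =>
          for (i <- grad1.indices) grad1(i) += grad2(i)
          (grad1, loss1 + loss2)
      })

  def main(args: Array[String]): Unit = {
    // Two "partitions" of (label, features) pairs.
    val data = Seq(
      Seq((1.0, Array(1.0, 0.0))),
      Seq((2.0, Array(0.0, 1.0)), (0.0, Array(1.0, 1.0))))
    val (grad, loss) = gradientAndLoss(data, Array(0.0, 0.0))
    println(s"grad=${grad.mkString("[", ", ", "]")} loss=$loss")
  }
}
```

With zero weights, running main prints grad=[-1.0, -2.0] loss=2.5. The data elements themselves never need to contain (grad, loss); that part of the pattern comes from the accumulator.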
That's all.