Tags: scala, apache-spark, apache-spark-mllib

Can't run LDA on Dataset[(scala.Long, org.apache.spark.mllib.linalg.Vector)] in Spark 2.0


I am following this tutorial video on an LDA example, and I'm getting the following error:

<console>:37: error: overloaded method value run with alternatives:
  (documents: org.apache.spark.api.java.JavaPairRDD[java.lang.Long,org.apache.spark.mllib.linalg.Vector])org.apache.spark.mllib.clustering.LDAModel <and>
  (documents: org.apache.spark.rdd.RDD[(scala.Long, org.apache.spark.mllib.linalg.Vector)])org.apache.spark.mllib.clustering.LDAModel
  cannot be applied to (org.apache.spark.sql.Dataset[(scala.Long, org.apache.spark.mllib.linalg.Vector)])
     val model = run(lda_countVector)
                                   ^

So I want to convert this DataFrame to an RDD, but the result always comes back as a Dataset. Can anyone please look into this issue?

// Convert DF to RDD
import org.apache.spark.mllib.linalg.Vector
val lda_countVector = countVectors.map { case Row(id: Long, countVector: Vector) => (id, countVector) }
// import org.apache.spark.mllib.linalg.Vector
// lda_countVector: org.apache.spark.sql.Dataset[(Long, org.apache.spark.mllib.linalg.Vector)] = [_1: bigint, _2: vector]

Solution

  • The Spark API changed between the 1.x and 2.x branches. In particular, DataFrame.map now returns a Dataset, not an RDD, so the result is not compatible with the old RDD-based MLlib API. You should convert the data to an RDD first, as follows:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.Row
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
    
    val a = Vectors.dense(Array(1.0, 2.0, 3.0))
    val b = Vectors.dense(Array(3.0, 4.0, 5.0))
    val df = Seq((1L, a), (2L, b), (3L, a)).toDF
    
    val ldaDF = df.rdd.map { 
      case Row(id: Long, countVector: Vector) => (id, countVector) 
    } 
    
    val model = new LDA().setK(3).run(ldaDF)
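Once trained, the topic structure can be inspected through the model. A minimal sketch, assuming the `model` fit above (following the pattern from the MLlib clustering guide):

```scala
// Each column of topicsMatrix is a topic; each row is a term in the
// vocabulary. Entries are the topics' term weights.
val topics = model.topicsMatrix
for (topic <- Range(0, 3)) {
  print(s"Topic $topic :")
  for (word <- Range(0, model.vocabSize)) {
    print(s" ${topics(word, topic)}")
  }
  println()
}
```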
    

    Or you can convert to a typed Dataset and then to an RDD:

    val model = new LDA().setK(3).run(df.as[(Long, Vector)].rdd)
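
    Alternatively, if you don't need the RDD-based API at all, Spark 2.x also ships a DataFrame-based LDA in the `org.apache.spark.ml` package that accepts a Dataset directly. A sketch (note the `ml` vector type differs from the `mllib` one, and the algorithm reads a "features" column by default):

```scala
import org.apache.spark.ml.clustering.LDA
import org.apache.spark.ml.linalg.{Vectors => MLVectors}

// Build a DataFrame with ml-package vectors in a "features" column;
// the ml LDA is fit directly on the DataFrame, no RDD conversion needed.
val mlDF = Seq(
  (1L, MLVectors.dense(1.0, 2.0, 3.0)),
  (2L, MLVectors.dense(3.0, 4.0, 5.0))
).toDF("id", "features")

val mlModel = new LDA().setK(3).fit(mlDF)
mlModel.describeTopics(3).show()
```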