Tags: scala, apache-spark, rdd, apache-spark-mllib, apache-spark-ml

How to convert from org.apache.spark.mllib.linalg.SparseVector to org.apache.spark.ml.linalg.SparseVector?


I am converting code from the mllib API to the ml API.

import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.ml.linalg.{DenseVector => NewDenseVector, Vector => NewVector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.feature.{LabeledPoint => NewLabeledPoint}

val labelPointData = limitedTable.rdd.map { row =>
  new NewLabeledPoint(convertToDouble(row.head), row(1).asInstanceOf[org.apache.spark.ml.linalg.SparseVector])
}

The expression row(1).asInstanceOf[org.apache.spark.ml.linalg.SparseVector] fails with the following exception:

org.apache.spark.mllib.linalg.SparseVector cannot be cast to org.apache.spark.ml.linalg.SparseVector

How can I overcome this?

I have found code for converting from mllib to ml, but not vice versa.


Solution

  • It is possible to convert in both directions. First, let's create an mllib SparseVector:

    import org.apache.spark.mllib.linalg.Vectors
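    // Indices are zero-based and must be smaller than the vector size, so indices 1, 2, 3 need size 4.
    val mllibVec: org.apache.spark.mllib.linalg.Vector = Vectors.sparse(4, Array(1, 2, 3), Array(1.0, 2.0, 3.0))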
    

    To convert to ML SparseVector, simply use asML:

    val mlVec: org.apache.spark.ml.linalg.Vector = mllibVec.asML
    

    To convert it back again, the easiest way is to use Vectors.fromML():

    val mllibVec2: org.apache.spark.mllib.linalg.Vector = Vectors.fromML(mlVec)
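
    The three steps above can be put together as a small round-trip check. This is only a sketch: the object name VectorRoundTrip and the import aliases are made up for illustration, and it assumes Spark 2.0 or later (where asML and Vectors.fromML exist) with spark-mllib on the classpath. No SparkSession is needed, since the vector classes are plain local objects.

    ```scala
    import org.apache.spark.mllib.linalg.{SparseVector => OldSparseVector, Vectors => OldVectors}
    import org.apache.spark.ml.linalg.{SparseVector => NewSparseVector}

    // Hypothetical wrapper object, used only to make the sketch runnable.
    object VectorRoundTrip {
      def main(args: Array[String]): Unit = {
        // mllib sparse vector of size 4 with non-zero entries at indices 1, 2, 3.
        val oldVec = OldVectors.sparse(4, Array(1, 2, 3), Array(1.0, 2.0, 3.0))

        // mllib -> ml
        val newVec = oldVec.asML
        // ml -> mllib
        val back = OldVectors.fromML(newVec)

        // Sparseness is preserved in both directions, and the values survive the round trip.
        assert(newVec.isInstanceOf[NewSparseVector])
        assert(back.isInstanceOf[OldSparseVector])
        assert(back == oldVec)
      }
    }
    ```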
    

    In addition, in your code, row(1).asInstanceOf[SparseVector] can be replaced with row.getAs[SparseVector](1). Since the column actually holds an mllib vector, read it as an mllib SparseVector, convert it with asML, and pass the result to the ml-based LabeledPoint, i.e.:

    val labelPointData = limitedTable.rdd.map { row =>
      NewLabeledPoint(convertToDouble(row.head), row.getAs[org.apache.spark.mllib.linalg.SparseVector](1).asML)
    }
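
If the vectors live in a DataFrame column, another option may be to convert the whole column before mapping over rows, using MLUtils.convertVectorColumnsToML (available since Spark 2.0). The sketch below is an assumption-laden illustration, not the asker's actual setup: the object name ConvertColumns, the toy data, and the column names "label" and "features" are all hypothetical stand-ins for limitedTable and its columns.

```scala
import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.SparkSession

// Hypothetical wrapper object, used only to make the sketch runnable.
object ConvertColumns {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("convert-vector-columns")
      .getOrCreate()
    import spark.implicits._

    // Toy DataFrame whose "features" column holds mllib vectors (assumed column names).
    val df = Seq(
      (1.0, OldVectors.sparse(4, Array(1, 3), Array(1.0, 2.0))),
      (0.0, OldVectors.sparse(4, Array(0), Array(5.0)))
    ).toDF("label", "features")

    // Convert the named mllib vector column to the ml vector type.
    val converted = MLUtils.convertVectorColumnsToML(df, "features")

    // After the conversion, the column can be read as an ml vector.
    val first = converted.select("features").head.getAs[org.apache.spark.ml.linalg.Vector](0)
    assert(first.size == 4)

    spark.stop()
  }
}
```

With the column already converted, the row-level cast in the original code would no longer need asML. MLUtils also offers convertVectorColumnsFromML for the opposite direction.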