Tags: python, apache-spark, pyspark, apache-spark-mllib, apache-spark-ml

Create a dataframe with SparseVector PySpark


Let's say I have a Spark dataframe whose rows look like this:

Row(Y=a, X1=3.2, X2=4.5)

What I want is:

Row(Y=a, features=SparseVector(2, {X1: 3.2, X2: 4.5}))

Solution

  • This may help. The example is written in Scala, but it ports to PySpark with minimal changes.

    Use VectorAssembler to combine the input columns into a single vector column:

        val df = spark.sql("select 'a' as Y, 3.2 as X1, 4.5 as X2")
        df.show(false)
        df.printSchema()
    
        /**
          * +---+---+---+
          * |Y  |X1 |X2 |
          * +---+---+---+
          * |a  |3.2|4.5|
          * +---+---+---+
          *
          * root
          * |-- Y: string (nullable = false)
          * |-- X1: decimal(2,1) (nullable = false)
          * |-- X2: decimal(2,1) (nullable = false)
          */
        import org.apache.spark.ml.feature.VectorAssembler
        val features = new VectorAssembler()
          .setInputCols(Array("X1", "X2"))
          .setOutputCol("features")
          .transform(df)
        features.show(false)
        features.printSchema()
    
        /**
          * +---+---+---+---------+
          * |Y  |X1 |X2 |features |
          * +---+---+---+---------+
          * |a  |3.2|4.5|[3.2,4.5]|
          * +---+---+---+---------+
          *
          * root
          * |-- Y: string (nullable = false)
          * |-- X1: decimal(2,1) (nullable = false)
          * |-- X2: decimal(2,1) (nullable = false)
          * |-- features: vector (nullable = true)
          */