
Update Spark dataframe column names based on Map type key value pairs


I have a Spark dataframe `df`, and I need to update its column names based on Map-type key/value pairs.

 df.show()

   +----+----+----+
   |col1|col2|col3|
   +----+----+----+
   |   2| Ive|1989|
   | Tom|null|1981|
   |   4|John|1991|
   +----+----+----+

 val Map_value = Map("col1" -> "id", "col2" -> "name", "col3" -> "year")

I am not sure how to proceed; any help is appreciated.

Expected output:

   +---+----+----+
   | id|name|year|
   +---+----+----+
   |  2| Ive|1989|
   |Tom|null|1981|
   |  4|John|1991|
   +---+----+----+
  

Solution

  • Given:

    case class ColData(col1: String, col2: String, col3: Int)
    

    defined at a top level (so that Spark's implicits can derive an encoder for it), together with the following:

        import org.apache.spark.sql.{DataFrame, Dataset}
        import org.apache.spark.sql.functions.col
        import sparkSession.implicits._

        val sourceSeq = Seq(
          ColData("2", "Ive", 1989),
          ColData("Tom", null, 1981),
          ColData("4", "John", 1991)
        )

        def mapFields[T](ds: Dataset[T], fieldNameMap: Map[String, String]): DataFrame = {
          // make sure the fields are present - note this is not a free operation
          val fieldNames = ds.schema.fieldNames.toSet
          // keep only entries whose key is a real column, then build the renamed projections
          val newNames = fieldNameMap.filterKeys(fieldNames).map {
            case (oldFieldName, newFieldName) => col(oldFieldName).as(newFieldName)
          }.toSeq

          // a single select keeps the whole rename down to one projection
          ds.select(newNames: _*)
        }

        val newNames = mapFields(sourceSeq.toDS(), Map("col1" -> "id", "col2" -> "name", "col3" -> "year", "not a field" -> "field"))

        newNames.show()
    

    yielding:

    +---+----+----+
    | id|name|year|
    +---+----+----+
    |  2| Ive|1989|
    |Tom|null|1981|
    |  4|John|1991|
    +---+----+----+
    

    Note:

    The fieldNames check uses ds.schema, which can be very expensive, so prefer passing known field names over calling .schema where possible. Similarly, chaining withColumn or withColumnRenamed over many fields can severely impact performance, because not all of the intermediate projections are removed from the generated code; prefer to keep the number of projections low, e.g. with a single select as above. A sketch contrasting the two styles follows.
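    For illustration, here is a minimal sketch contrasting a foldLeft over withColumnRenamed with the single-select approach used above. It assumes the same sparkSession and sourceSeq defined earlier; the knownFields set is a stand-in for field names you already know, used so the rename never has to touch ds.schema.

        import org.apache.spark.sql.functions.col
        import sparkSession.implicits._

        val renames = Map("col1" -> "id", "col2" -> "name", "col3" -> "year")
        val ds = sourceSeq.toDS()

        // foldLeft + withColumnRenamed: adds one projection per renamed column,
        // and Catalyst does not always collapse them in the generated code
        val renamedByFold = renames.foldLeft(ds.toDF()) {
          case (df, (oldName, newName)) => df.withColumnRenamed(oldName, newName)
        }

        // single select: one projection no matter how many columns are renamed,
        // using a known field set instead of the (not free) ds.schema lookup
        val knownFields = Set("col1", "col2", "col3")
        val renamedBySelect = ds.select(
          renames.collect {
            case (oldName, newName) if knownFields(oldName) => col(oldName).as(newName)
          }.toSeq: _*
        )

        renamedBySelect.show()

    Both produce the same renamed output here; the difference only shows up in the query plan and generated code as the number of renamed columns grows.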