Tags: python, apache-spark, dataframe, pyspark, rename

Rename nested field in Spark DataFrame


Given a DataFrame df in Spark with the following schema:

 |-- array_field: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- a: string (nullable = true)
 |    |    |-- b: long (nullable = true)
 |    |    |-- c: long (nullable = true)

How can I rename the field array_field.a to array_field.a_renamed?

[Update]:

.withColumnRenamed() does not work with nested fields, so I tried this hacky and unsafe method:

# First alter the schema:
schema = df.schema
schema['array_field'].dataType.elementType['a'].name = 'a_renamed'

ind = schema['array_field'].dataType.elementType.names.index('a')
schema['array_field'].dataType.elementType.names[ind] = 'a_renamed'

# Then set dataframe's schema with altered schema
df._schema = schema

I know that setting a private attribute is not good practice, but I don't know any other way to set the schema for df.

I think I am on the right track, but df.printSchema() still shows the old name for array_field.a, even though df.schema == schema is True.


Solution

    Python

    It is not possible to modify a single nested field in place. You have to recreate the whole structure. This is also why your hack appears to succeed but changes nothing: printSchema() asks the JVM for the plan's schema, while df.schema just returns the cached Python-side object you mutated, so the two can disagree. In this particular case the simplest solution is to use cast.

    First, a bunch of imports:

    from collections import namedtuple
    from pyspark.sql.functions import col
    from pyspark.sql.types import (
        ArrayType, LongType, StringType, StructField, StructType)
    

    and example data:

    Record = namedtuple("Record", ["a", "b", "c"])
    
    df = sc.parallelize([([Record("foo", 1, 3)], )]).toDF(["array_field"])
    
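    On newer PySpark versions, where no SparkContext named sc is in scope by default, the same frame can be built through a SparkSession instead. A minimal sketch, assuming the usual spark session entry point:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Schema is inferred from the namedtuple, exactly as with sc.parallelize
    df = spark.createDataFrame([([Record("foo", 1, 3)],)], ["array_field"])
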

    Let's confirm that the schema is the same as in your case:

    df.printSchema()
    
    root
     |-- array_field: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- a: string (nullable = true)
     |    |    |-- b: long (nullable = true)
     |    |    |-- c: long (nullable = true)
    

    You can define the new schema, for example, as a string:

    str_schema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"
    
    df.select(col("array_field").cast(str_schema)).printSchema()
    
    root
     |-- array_field: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- a_renamed: string (nullable = true)
     |    |    |-- b: long (nullable = true)
     |    |    |-- c: long (nullable = true)
    

    or a DataType:

    struct_schema = ArrayType(StructType([
        StructField("a_renamed", StringType()),
        StructField("b", LongType()),
        StructField("c", LongType())
    ]))
    
    df.select(col("array_field").cast(struct_schema)).printSchema()
    
    root
     |-- array_field: array (nullable = true)
     |    |-- element: struct (containsNull = true)
     |    |    |-- a_renamed: string (nullable = true)
     |    |    |-- b: long (nullable = true)
     |    |    |-- c: long (nullable = true)
    
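    Both variants spell out the untouched fields (b, c) by hand. If the struct is wide, the element type can be rebuilt from the existing schema instead. A minimal sketch, assuming the column is an array of structs; the helper name rename_array_struct_field is mine, not a Spark API:

    from pyspark.sql.functions import col
    from pyspark.sql.types import ArrayType, StructField, StructType

    def rename_array_struct_field(df, array_col, old, new):
        # Copy every field of the element struct, renaming the matching one
        element = df.schema[array_col].dataType.elementType
        renamed = StructType([
            StructField(new if f.name == old else f.name, f.dataType, f.nullable)
            for f in element.fields])
        # cast rebuilds the whole structure, exactly as above
        return df.withColumn(array_col, col(array_col).cast(ArrayType(renamed)))

    rename_array_struct_field(df, "array_field", "a", "a_renamed").printSchema()
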

    Scala

    The same techniques can be used in Scala:

    case class Record(a: String, b: Long, c: Long)
    
    val df = Seq(Tuple1(Seq(Record("foo", 1, 3)))).toDF("array_field")
    
    val strSchema = "array<struct<a_renamed:string,b:bigint,c:bigint>>"
    
    df.select($"array_field".cast(strSchema))
    

    or

    import org.apache.spark.sql.types._
    
    val structSchema = ArrayType(StructType(Seq(
        StructField("a_renamed", StringType),
        StructField("b", LongType),
        StructField("c", LongType)
    )))
    
    df.select($"array_field".cast(structSchema))
    

    Possible improvements:

    If you use an expressive data-manipulation or JSON-processing library, it can be easier to dump the data types to a dict or a JSON string and take it from there. For example (Python / toolz):

    from operator import attrgetter

    from toolz.curried import map, pipe, update_in

    # Rename the field to "a_renamed" if its name is "a"
    rename_field = update_in(
        keys=["name"], func=lambda x: "a_renamed" if x == "a" else x)

    updated_schema = pipe(
        # Get the schema of the field as a dict
        df.schema["array_field"].jsonValue(),
        # Rename the matching field inside the element struct
        update_in(
            keys=["type", "elementType", "fields"],
            func=lambda x: pipe(x, map(rename_field), list)),
        # Load the schema back from the dict
        StructField.fromJson,
        # Extract the data type to cast to
        attrgetter("dataType"))

    df.select(col("array_field").cast(updated_schema)).printSchema()
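    Since jsonValue() returns plain nested dicts, the same round trip also works with nothing but the standard library. A minimal sketch without toolz, under the same assumptions:

    # Dump the field's schema to a dict, rename the field, and load it back
    d = df.schema["array_field"].jsonValue()
    for f in d["type"]["elementType"]["fields"]:
        if f["name"] == "a":
            f["name"] = "a_renamed"

    updated_schema = StructField.fromJson(d).dataType
    df.select(col("array_field").cast(updated_schema)).printSchema()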