Tags: pyspark, apache-spark-sql, orc

Access fields of an array within a PySpark DataFrame


I am developing SQL queries against a Spark DataFrame built from a group of ORC files. The program goes like this:

from pyspark.sql import SparkSession
spark_session = SparkSession.builder.appName("test").getOrCreate()
sdf = spark_session.read.orc("../data/")
sdf.createOrReplaceTempView("test")

Now I have a table called "test". If I do something like:

spark_session.sql("select count(*) from test")

then the result is fine. But I need to select more columns in the query, including some of the fields in the array.

In [8]: sdf.take(1)[0]["customtags"]
Out[8]:
[Row(name='name', value='tom'),
 Row(name='age', value='20'),
 Row(name='gender', value='m')]

I have tried something like:

spark_session.sql("select person.age, count(*) from test group by person.age")

But this does not work. My question is: how do I access the fields inside the "customtags" array?

Thanks!

EDIT:

Result of sdf.printSchema():

In [3]: sdf.printSchema()
root
 |-- person: integer (nullable = true)
 |-- customtags: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- name: string (nullable = true)
 |    |    |-- value: string (nullable = true)

Error message:

AnalysisException: 'No such struct field age in name, value; line 16 pos 8'

Solution

  • I don't know how to do this using only Spark SQL, but here is a way to do it using the PySpark DataFrame API.

    Basically, we can convert the array-of-structs column into a MapType() using the create_map() function. Then we can access the fields directly by string indexing.

    Consider the following example:

    Define Schema

    from pyspark.sql.types import (
        ArrayType, IntegerType, StringType, StructField, StructType
    )

    schema = StructType([
        StructField('person', IntegerType()),
        StructField(
            'customtags',
            ArrayType(
                StructType([
                    StructField('name', StringType()),
                    StructField('value', StringType())
                ])
            )
        )
    ])
    

    Create Example DataFrame

    data = [
        (
            1, 
            [
                {'name': 'name', 'value': 'tom'},
                {'name': 'age', 'value': '20'},
                {'name': 'gender', 'value': 'm'}
            ]
        ),
        (
            2,
            [
                {'name': 'name', 'value': 'jerry'},
                {'name': 'age', 'value': '20'},
                {'name': 'gender', 'value': 'm'}
            ]
        ),
        (
            3,
            [
                {'name': 'name', 'value': 'ann'},
                {'name': 'age', 'value': '20'},
                {'name': 'gender', 'value': 'f'}
            ]
        )
    ]
    df = spark_session.createDataFrame(data, schema)
    df.show(truncate=False)
    #+------+------------------------------------+
    #|person|customtags                          |
    #+------+------------------------------------+
    #|1     |[[name,tom], [age,20], [gender,m]]  |
    #|2     |[[name,jerry], [age,20], [gender,m]]|
    #|3     |[[name,ann], [age,20], [gender,f]]  |
    #+------+------------------------------------+
    

    Convert the array-of-structs column to a map

    from functools import reduce
    from operator import add

    import pyspark.sql.functions as f

    # Interleave the name and value fields of each array element so that
    # create_map() receives alternating key, value arguments.
    df = df.withColumn(
            'customtags',
            f.create_map(
                *reduce(
                    add,
                    [
                        [f.col('customtags')['name'][i],
                         f.col('customtags')['value'][i]] for i in range(3)
                    ]
                )
            )
        )\
        .select('person', 'customtags')
    
    df.show(truncate=False)
    #+------+------------------------------------------+
    #|person|customtags                                |
    #+------+------------------------------------------+
    #|1     |Map(name -> tom, age -> 20, gender -> m)  |
    #|2     |Map(name -> jerry, age -> 20, gender -> m)|
    #|3     |Map(name -> ann, age -> 20, gender -> f)  |
    #+------+------------------------------------------+
    

    The catch here is that you have to know the length of the ArrayType() a priori (in this case 3), as I don't know of a way to loop over it dynamically. This also assumes that the array has the same length for every row.
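
    One way around this limitation, assuming you are on Spark 2.4 or later: the built-in map_from_entries() function treats each struct in the array as a (key, value) entry, so the array can have a different length per row and you don't need to know it up front. A minimal sketch, starting again from the example DataFrame defined above:

    import pyspark.sql.functions as f

    # Rebuild the example DataFrame with the original array-of-structs column.
    df2 = spark_session.createDataFrame(data, schema)

    # map_from_entries() (Spark 2.4+) converts array<struct<name,value>>
    # directly into a map, regardless of the array length.
    df2 = df2.withColumn('customtags', f.map_from_entries('customtags'))
    df2.show(truncate=False)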

    Back in the create_map() version, I had to use reduce(add, ...) because create_map() expects its arguments as flat (key, value) pairs.
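
    To see what the reduce(add, ...) is doing, here is the same flattening in plain Python (with placeholder keys and values for illustration):

    from functools import reduce
    from operator import add

    # Concatenating the inner two-element lists yields one flat list, which
    # create_map() then consumes as alternating key/value arguments.
    pairs = [['k0', 'v0'], ['k1', 'v1'], ['k2', 'v2']]
    print(reduce(add, pairs))
    # ['k0', 'v0', 'k1', 'v1', 'k2', 'v2']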

    Group by fields in the map column

    df.groupBy((f.col('customtags')['name']).alias('name')).count().show()
    #+-----+-----+
    #| name|count|
    #+-----+-----+
    #|  ann|    1|
    #|jerry|    1|
    #|  tom|    1|
    #+-----+-----+
    
    df.groupBy((f.col('customtags')['gender']).alias('gender')).count().show()
    #+------+-----+
    #|gender|count|
    #+------+-----+
    #|     m|    2|
    #|     f|    1|
    #+------+-----+
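
    Since the question asked for Spark SQL, note that once customtags is a map you can also query it by key from SQL after registering the converted DataFrame as a temporary view (the view name test2 below is just for illustration):

    # Register the converted DataFrame (customtags is now a map) as a view.
    df.createOrReplaceTempView('test2')

    # Map columns can be indexed by key directly in Spark SQL.
    spark_session.sql(
        "SELECT customtags['age'] AS age, COUNT(*) AS cnt "
        "FROM test2 GROUP BY customtags['age']"
    ).show()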