Tags: xml · apache-spark · pyspark · apache-spark-sql · user-defined-functions

Parsing the nested XML fields from PySpark Dataframe using UDF


I have a scenario where I have XML data in a dataframe column.

sex updated_at visitors
F 1574264158 <?xml version="1.0" encoding="utf-8

I want to parse the nested XML fields in the visitors column into separate Dataframe columns, using a UDF.

Format of XML

<?xml version="1.0" encoding="utf-8"?>
<visitors>
    <visitor id="9615" age="68" sex="F" />
    <visitor id="1882" age="34" sex="M" />
    <visitor id="5987" age="23" sex="M" />
</visitors>
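As a quick sanity check outside Spark, this structure parses with Python's standard-library ElementTree (a standalone sketch using the sample XML above, not part of the Spark job):

```python
import xml.etree.ElementTree as ET

xml_str = (
    '<?xml version="1.0" encoding="utf-8"?>'
    '<visitors>'
    '<visitor id="9615" age="68" sex="F" />'
    '<visitor id="1882" age="34" sex="M" />'
    '<visitor id="5987" age="23" sex="M" />'
    '</visitors>'
)

root = ET.fromstring(xml_str)  # the root element is <visitors>
# one attribute dict per <visitor> child element
rows = [v.attrib for v in root.findall('visitor')]
print(rows[0])  # {'id': '9615', 'age': '68', 'sex': 'F'}
```

This is exactly the shape the UDF approach below relies on: each visitor becomes a dict of its attributes.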

Solution

  • You can use xpath queries without using UDFs:

    df = spark.createDataFrame([['<?xml version="1.0" encoding="utf-8"?> <visitors> <visitor id="9615" age="68" sex="F" /> <visitor id="1882" age="34" sex="M" /> <visitor id="5987" age="23" sex="M" /> </visitors>']], ['visitors'])
    
    df.show(truncate=False)
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |visitors                                                                                                                                                                          |
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    |<?xml version="1.0" encoding="utf-8"?> <visitors> <visitor id="9615" age="68" sex="F" /> <visitor id="1882" age="34" sex="M" /> <visitor id="5987" age="23" sex="M" /> </visitors>|
    +----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
    
    
    df2 = df.selectExpr(
        "xpath(visitors, './visitors/visitor/@id') id",
        "xpath(visitors, './visitors/visitor/@age') age",
        "xpath(visitors, './visitors/visitor/@sex') sex"
    ).selectExpr(
        "explode(arrays_zip(id, age, sex)) visitors"
    ).select('visitors.*')
    
    df2.show(truncate=False)
    +----+---+---+
    |id  |age|sex|
    +----+---+---+
    |9615|68 |F  |
    |1882|34 |M  |
    |5987|23 |M  |
    +----+---+---+
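Note that each xpath call returns a per-row array of strings, so arrays_zip pairs the three arrays element-wise into an array of structs, and explode emits one output row per struct. The reshaping step can be mimicked in plain Python to see what it does (a conceptual sketch, not Spark code; the sample values come from the data above):

```python
# per-row output of the three xpath calls: three parallel string arrays
ids   = ['9615', '1882', '5987']
ages  = ['68', '34', '23']
sexes = ['F', 'M', 'M']

# arrays_zip: combine parallel arrays element-wise into an array of structs
zipped = list(zip(ids, ages, sexes))

# explode: one output row per struct
for id_, age, sex in zipped:
    print(id_, age, sex)
```

Also keep in mind the xpath result columns are arrays of strings; if you need numeric types, cast after exploding.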
    

    If you insist on using UDFs:

    import xml.etree.ElementTree as ET
    import pyspark.sql.functions as F
    
    @F.udf('array<struct<id:string, age:string, sex:string>>')
    def parse_xml(s):
        # the document root is <visitors>; each child <visitor> holds the attributes
        root = ET.fromstring(s)
        # attrib is a dict of the element's attributes, matched to the struct fields by name
        return [v.attrib for v in root.findall('visitor')]
        
    df2 = df.select(
        F.explode(parse_xml('visitors')).alias('visitors')
    ).select('visitors.*')
    
    df2.show()
    +----+---+---+
    |  id|age|sex|
    +----+---+---+
    |9615| 68|  F|
    |1882| 34|  M|
    |5987| 23|  M|
    +----+---+---+
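One caveat with the UDF approach: ET.fromstring raises a ParseError on malformed or empty input, which would fail the whole Spark task. A defensive variant of the parsing logic (shown here as plain Python; the same body can be used inside the UDF) returns an empty list instead:

```python
import xml.etree.ElementTree as ET

def parse_visitors(s):
    """Return one attribute dict per <visitor>, or [] for null/bad input."""
    if not s:
        return []
    try:
        root = ET.fromstring(s)
    except ET.ParseError:
        # malformed XML: skip the row instead of killing the task
        return []
    return [v.attrib for v in root.findall('visitor')]

print(parse_visitors('<visitors><visitor id="1" age="2" sex="F"/></visitors>'))
print(parse_visitors('not xml'))  # []
```

Rows with bad XML then simply produce no output rows after explode, rather than an exception.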