
Converting pyspark dataframe into dictionary: result different than expected


Suppose I have the following PySpark DataFrame:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [("USA", 20, 40, 60),
        ("India", 50, 40, 30),
        ("Nepal", 20, 50, 30),
        ("Ireland", 40, 60, 70),
        ("Norway", 50, 50, 60)]

columns = ["country", "A", "B", "C"]

df = spark.createDataFrame(data=data, schema=columns)

To create a dictionary from it, I used this approach:

# Collect all rows to the driver and convert each Row to a plain dict
list_test = [row.asDict() for row in df.collect()]
# Key each row dict by its country name
dict_test = {country['country']: country for country in list_test}
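Stripped of Spark, this comprehension only re-keys the row dictionaries; it never removes anything from them. The sketch below assumes plain dicts standing in for the `row.asDict()` output (so it runs without a Spark session) and shows that each value is the very same, complete row dict:

```python
# Plain dicts standing in for row.asDict() output (assumption: no Spark session needed).
list_test = [
    {"country": "USA", "A": 20, "B": 40, "C": 60},
    {"country": "India", "A": 50, "B": 40, "C": 30},
]

# The value side of the comprehension is the entire row dict,
# so the "country" key is carried into every inner dictionary.
dict_test = {country["country"]: country for country in list_test}

print(dict_test["USA"])                  # {'country': 'USA', 'A': 20, 'B': 40, 'C': 60}
print(dict_test["USA"] is list_test[0])  # True: same object, only re-keyed
```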

The result is as follows:

{'USA': {'country': 'USA', 'A': 20, 'B': 40, 'C': 60}, 'India': {'country': 'India', 'A': 50, 'B': 40, 'C': 30}, 'Nepal': {'country': 'Nepal', 'A': 20, 'B': 50, 'C': 30}, 'Ireland': {'country': 'Ireland', 'A': 40, 'B': 60, 'C': 70}, 'Norway': {'country': 'Norway', 'A': 50, 'B': 50, 'C': 60}}

However, what I wanted was the following:

{'USA': {'A': 20, 'B': 40, 'C': 60}, 'India': {'A': 50, 'B': 40, 'C': 30}, 'Nepal': {'A': 20, 'B': 50, 'C': 30}, 'Ireland': {'A': 40, 'B': 60, 'C': 70}, 'Norway': {'A': 50, 'B': 50, 'C': 60}}

How can I obtain this? I'm not sure I understand what I'm doing wrong.


Solution

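  • `row.asDict()` returns every column of the row, including `country`, and your comprehension uses that whole dictionary as the value. Use a nested dict comprehension to drop the `country` key from each inner dictionary: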

    list_test = [row.asDict() for row in df.collect()]
    # For each row, keep every key except 'country' in the inner dict
    dict_test = {country['country']: {k: v for k, v in country.items() if k != 'country'}
                 for country in list_test}

    print(dict_test)
    # {'USA': {'A': 20, 'B': 40, 'C': 60}, 'India': {'A': 50, 'B': 40, 'C': 30}, 'Nepal': {'A': 20, 'B': 50, 'C': 30}, 'Ireland': {'A': 40, 'B': 60, 'C': 70}, 'Norway': {'A': 50, 'B': 50, 'C': 60}}
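If you need this for other key columns, the same idea can be wrapped in a small helper. `rows_to_keyed_dict` below is a hypothetical name, not part of PySpark; it assumes the rows are already plain dicts (for example `row.asDict() for row in df.collect()`) and uses `dict.pop` on a copy instead of a filtering comprehension. As with the original comprehension, a repeated key value silently overwrites the earlier entry.

```python
from typing import Any, Dict, Iterable


def rows_to_keyed_dict(rows: Iterable[Dict[str, Any]], key: str) -> Dict[Any, Dict[str, Any]]:
    """Hypothetical helper: map each row's `key` value to the row's other columns."""
    result: Dict[Any, Dict[str, Any]] = {}
    for row in rows:
        rest = dict(row)           # shallow copy so the caller's dicts are not mutated
        key_value = rest.pop(key)  # remove the key column and keep its value
        result[key_value] = rest   # later duplicates overwrite earlier ones
    return result


# Plain dicts standing in for row.asDict() output (assumption: no Spark session needed).
list_test = [
    {"country": "USA", "A": 20, "B": 40, "C": 60},
    {"country": "India", "A": 50, "B": 40, "C": 30},
]

dict_test = rows_to_keyed_dict(list_test, "country")
print(dict_test)  # {'USA': {'A': 20, 'B': 40, 'C': 60}, 'India': {'A': 50, 'B': 40, 'C': 30}}
```

With Spark, you would call it as `rows_to_keyed_dict((row.asDict() for row in df.collect()), "country")`. Like any `collect()`-based approach, this pulls the whole DataFrame onto the driver, so it only suits small results.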