Tags: python, function, loops, apache-spark, pyspark

Recasting column types with a function and a dictionary in pyspark


I have a large dataset with many columns. I want to write a function, using pyspark, which does the following:

1. Define a dictionary mapping data types (keys) to lists of column names (values).
2. Look up each column of the dataframe in the dictionary.
3. If a column is found under a type, cast the column to that type.

I got the function to the point where it does not throw an error. Unfortunately, it does not change the column types, and I could not find the mistake. Can someone spot the problem? Thanks :)

import pyspark
from pyspark.sql.types import StringType, IntegerType, ArrayType

# This creates a sample dataframe
simpleData = [("James", "Sales", 3000),
    ("Michael", "Sales", 4600),
    ("Robert", "Sales", 4100),
    ("Kumar", "Marketing", 2000),
    ("Saif", "Sales", 4100)]
schema = ["employee_name", "department", "salary"]
table = spark.createDataFrame(data=simpleData, schema=schema)

# This is the function which is supposed to change datatypes en bulk
def recasting_function(data):
    df = data

    column_types = {
        "StringType()": ["employee_name", "department"],
        "IntegerType()": ["salary"]
        }

    for column in df.columns:
        if column in column_types.items():
            df = df.withColumn(item, df.item.cast(key))
    
    return df
        
# Here I apply it to my sample dataset
result = recasting_function(table)

Solution

  • Notice that item and key in your code are undefined variables, but no error is thrown because the if clause is always False: column_types.items() yields (key, value) tuples, so a bare column name is never a member and the body of the if never runs. Try the code below:

    # This creates a sample dataframe
    simpleData = [("James", "Sales", 3000),
        ("Michael", "Sales", 4600),
        ("Robert", "Sales", 4100),
        ("Kumar", "Marketing", 2000),
        ("Saif", "Sales", 4100)]
    schema = ["employee_name", "department", "salary"]
    table = spark.createDataFrame(data=simpleData, schema=schema)
    
    # This is the function which is supposed to change datatypes en bulk
    def recasting_function(data):
        column_types = {
            "string": ["employee_name", "department"],
            "int": ["salary"]
            }
        for (k, v) in column_types.items():
            for c in v:
                if c in data.columns:
                    data = data.withColumn(c, data[c].cast(k))    
        return data
            
    # Here I apply it to my sample dataset
    result = recasting_function(table)
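The failed membership test can be reproduced without Spark. Here is a minimal pure-Python sketch (no Spark session needed) showing why the original if never fires, and how inverting the dictionary gives a direct column-to-type lookup; the variable name cast_for is just an illustration, not part of any API:

```python
column_types = {
    "string": ["employee_name", "department"],
    "int": ["salary"],
}

# dict.items() yields (key, value) tuples, so a bare column name
# is never a member -- this is why the original `if` never fired
assert "salary" not in column_types.items()

# Inverting the mapping gives a direct column -> type lookup
cast_for = {col: dtype for dtype, cols in column_types.items() for col in cols}
assert cast_for == {
    "employee_name": "string",
    "department": "string",
    "salary": "int",
}
```

With such a reverse mapping in hand, the Spark version reduces to a single pass over data.columns, casting each column that appears in the lookup (data.withColumn(c, data[c].cast(cast_for[c]))), which avoids the nested loop in the accepted solution.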