I have a large dataset with many columns. I want to write a function, using PySpark, that does the following:

1. Define a dictionary whose keys are data types and whose values are lists of column names.
2. Look up each of the DataFrame's columns in the dictionary.
3. If a column is found under a type, cast the column to that type.

I've gotten the function to the point where it doesn't throw an error. Unfortunately, it doesn't change the column types either, and I can't find the mistake. Can someone spot the problem? Thanks :)
```python
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, IntegerType, ArrayType

spark = SparkSession.builder.getOrCreate()

# This creates a sample dataframe
simpleData = [("James", "Sales", 3000),
              ("Michael", "Sales", 4600),
              ("Robert", "Sales", 4100),
              ("Kumar", "Marketing", 2000),
              ("Saif", "Sales", 4100)]
schema = ["employee_name", "department", "salary"]
table = spark.createDataFrame(data=simpleData, schema=schema)

# This is the function which is supposed to change data types in bulk
def recasting_function(data):
    df = data
    column_types = {
        "StringType()": ["employee_name", "department"],
        "IntegerType()": ["salary"]
    }
    for column in df.columns:
        if column in column_types.items():
            df = df.withColumn(item, df.item.cast(key))
    return df

# Here I apply it to my sample dataset
result = recasting_function(table)
```
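Notice that `item` and `key` in your code are undefined variables, but no error is raised because the `if` clause is always False. `column_types.items()` yields `(key, value)` tuples such as `("IntegerType()", ["salary"])`, and a bare column-name string is never equal to one of those tuples, so the body (and the `NameError` it would raise) never runs. Two further problems would surface once the condition matched: `df.item` looks for a column literally named `item`, and `cast` expects either a `DataType` instance such as `StringType()` or a type-name string such as `"string"` or `"int"`, not the string `"StringType()"`. Try the code below: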
```python
from pyspark.sql import SparkSession

# Reuse the active Spark session (notebooks usually provide `spark` already)
spark = SparkSession.builder.getOrCreate()

# This creates a sample dataframe
simpleData = [("James", "Sales", 3000),
              ("Michael", "Sales", 4600),
              ("Robert", "Sales", 4100),
              ("Kumar", "Marketing", 2000),
              ("Saif", "Sales", 4100)]
schema = ["employee_name", "department", "salary"]
table = spark.createDataFrame(data=simpleData, schema=schema)

# This function changes data types in bulk
def recasting_function(data):
    # Keys are type-name strings accepted by Column.cast; values are column names
    column_types = {
        "string": ["employee_name", "department"],
        "int": ["salary"]
    }
    for (k, v) in column_types.items():
        for c in v:
            # Skip names in the dictionary that this DataFrame does not have
            if c in data.columns:
                data = data.withColumn(c, data[c].cast(k))
    return data

# Here I apply it to my sample dataset
result = recasting_function(table)
result.printSchema()
```
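To see concretely why the original `if` never fired, the membership test can be reproduced in plain Python, without Spark. This sketch reuses the question's dictionary; the `any(...)` line is just one way to ask whether a column appears in any of the lists, not the only possible fix:

```python
# Same shape as the question's dictionary: type name -> list of column names
column_types = {
    "StringType()": ["employee_name", "department"],
    "IntegerType()": ["salary"],
}

# .items() yields (key, value) tuples, so a bare column name never matches
found_in_items = "salary" in column_types.items()
print(found_in_items)  # False

# Only a complete (key, value) pair counts as a member of the items view
pair_in_items = ("IntegerType()", ["salary"]) in column_types.items()
print(pair_in_items)  # True

# Searching the lists of column names is what the original check intended
found_in_values = any("salary" in cols for cols in column_types.values())
print(found_in_values)  # True
```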
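Calling `withColumn` in a loop adds one projection per call, which the PySpark documentation warns can produce large query plans when many columns are involved. A minimal alternative sketch, assuming the same `{type_name: [column, ...]}` dictionary as the answer, inverts the dictionary once into a column-to-type lookup and casts everything in a single `select`. The helper names `build_cast_lookup` and `recast_columns` are hypothetical, and if a column is listed under more than one type, the last one wins:

```python
def build_cast_lookup(column_types):
    """Invert {type_name: [columns]} into {column: type_name}."""
    lookup = {}
    for type_name, columns in column_types.items():
        for column in columns:
            # A column listed under several types keeps the last one seen
            lookup[column] = type_name
    return lookup


def recast_columns(df, column_types):
    """Cast every listed column in one select(); other columns pass through."""
    lookup = build_cast_lookup(column_types)
    return df.select([
        df[c].cast(lookup[c]).alias(c) if c in lookup else df[c]
        for c in df.columns  # preserves the original column order
    ])


# The lookup step is plain Python and can be checked without Spark
column_types = {
    "string": ["employee_name", "department"],
    "int": ["salary"],
}
print(build_cast_lookup(column_types))
# {'employee_name': 'string', 'department': 'string', 'salary': 'int'}
```

With the sample `table` from above, `recast_columns(table, column_types).printSchema()` should list `salary` as an integer column. On Spark 3.3 and later, `DataFrame.withColumns` is another way to apply several column expressions in one call.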