I want to write a function that converts every decimal column of a Spark DataFrame to float.
I do not know the column names in advance, nor whether (or how many) decimal columns are present. That rules out casting each column explicitly by name, which would not scale anyway.
Columns of other data types should not be affected.
NULLs may occur.
The reason behind all this madness: I need to convert the Spark DataFrame to pandas so that I can write an xlsx file. Converting decimal columns to pandas, however, yields object columns, which are stored in the xlsx file as text rather than as numbers.
Sample code:
df = spark.sql("select 'text' as txt, 1.1111 as one, 2.22222 as two, CAST(3.333333333333 AS FLOAT) as three")
df.printSchema()
>>
root
|-- txt: string (nullable = false)
|-- one: decimal(5,4) (nullable = false)
|-- two: decimal(6,5) (nullable = false)
|-- three: float (nullable = false)
Transform to Pandas:
df_pd = df.toPandas()
print(df_pd.dtypes)
>>
txt object
one object
two object
three float32
dtype: object
I need all of the decimal types to be of float type in df_pd.
Ideally I would have something like this:
df = spark.sql("select 'text' as txt, 1.1111 as one, 2.22222 as two, 3.333333333333 as three")
# insert magic here
df.printSchema()
>>
root
|-- txt: string (nullable = false)
|-- one: float (nullable = false)
|-- two: float (nullable = false)
|-- three: float (nullable = false)
Thanks
To resolve this, you can use the code below. As a sample, I build a DataFrame from the four columns above and register it as a temp view.
Code:
from pyspark.sql.functions import col
from pyspark.sql.types import DecimalType, FloatType
df1 = spark.sql("""
SELECT 'text' AS txt,
CAST(1.1111 AS DECIMAL(5,4)) AS one,
CAST(2.22222 AS DECIMAL(6,5)) AS two,
CAST(3.333333333333 AS FLOAT) AS three
""")
df1.createOrReplaceTempView("deci_table")
def convert_decimal_to_float_from_table(table_name):
    df12 = spark.sql(f"SELECT * FROM {table_name}")
    # collect the names of all DecimalType columns, whatever their precision and scale
    decimal_columns = [field.name for field in df12.schema.fields if isinstance(field.dataType, DecimalType)]
    # cast each decimal column to float; all other columns are left untouched
    for col_name in decimal_columns:
        df12 = df12.withColumn(col_name, col(col_name).cast(FloatType()))
    return df12
df1.printSchema()
df12_conv = convert_decimal_to_float_from_table("deci_table")
df12_conv.printSchema()
df_pd = df12_conv.toPandas()
print(df_pd.dtypes)
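Output: after the conversion, printSchema() lists one, two, and three as float, and df_pd.dtypes reports float32 for those columns.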
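
If changing the Spark side is not an option, the same conversion could also be done after toPandas(). The following is only a minimal sketch, assuming pandas is installed and that decimal columns arrive as object columns holding decimal.Decimal values (as the dtypes in the question suggest). The helper name decimals_to_float and the sample frame are hypothetical and stand in for the real df.toPandas() result. NULLs end up as NaN.

```python
from decimal import Decimal

import pandas as pd


def decimals_to_float(pdf: pd.DataFrame) -> pd.DataFrame:
    """Return a copy in which object columns holding only Decimal/NULL values become float64."""
    out = pdf.copy()
    for name in out.columns:
        column = out[name]
        if column.dtype != object:
            continue  # already numeric, datetime, etc.
        non_null = column.dropna()
        # Skip all-NULL columns and columns containing any non-Decimal value (e.g. strings).
        if non_null.empty or not all(isinstance(v, Decimal) for v in non_null):
            continue
        # Decimal -> float, anything else in this column is a NULL -> NaN.
        out[name] = [float(v) if isinstance(v, Decimal) else float("nan") for v in column]
    return out


# Hypothetical stand-in for df.toPandas(): decimals arrive as Python Decimal objects.
sample = pd.DataFrame({
    "txt": ["text", "more"],
    "one": [Decimal("1.1111"), None],
    "three": [3.5, 4.5],
})
converted = decimals_to_float(sample)
print(converted.dtypes)
```

Note that this sketch produces float64, whereas casting to FloatType in Spark gives float32 in pandas. If the extra precision matters for the spreadsheet, casting to DoubleType on the Spark side would be the closer equivalent.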