Tags: python, pandas, apache-spark, databricks

Converting all columns in spark df from decimal to float for pandas conversion


I want to create a function that transforms the datatype of all Spark dataframe columns from decimal to float.
I do not know the column names in advance, nor whether (or how many) columns of type decimal are included. This rules out explicit per-column casts, which would also run into precision/scale limitations.
Columns of other data types should not be affected.
NULLs may occur.

Reason behind all this madness: I need to convert the Spark dataframe to pandas in order to write an xlsx file. Converting decimal columns to pandas, however, results in an object dtype, which is stored in the xlsx file as text, not as a number.
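
For context, the downstream step is roughly this (the file name is just illustrative, and to_excel needs an Excel writer such as openpyxl installed):

df_pd = df.toPandas()
# decimal columns arrive as dtype "object" and are written to the sheet as text
df_pd.to_excel("output.xlsx", index=False)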

Sample code:

df = spark.sql("select 'text' as txt, 1.1111 as one, 2.22222 as two, CAST(3.333333333333 AS FLOAT) as three")
df.printSchema()

>>
root
 |-- txt: string (nullable = false)
 |-- one: decimal(5,4) (nullable = false)
 |-- two: decimal(6,5) (nullable = false)
 |-- three: float (nullable = false)

Transform to Pandas:

df_pd = df.toPandas()
print(df_pd.dtypes)

>>
txt       object
one       object
two       object
three    float32
dtype: object

I need all of the decimal types to be of float type in df_pd.


Ideally, I'd end up with something like this:

df = spark.sql("select 'text' as txt, 1.1111 as one, 2.22222 as two, 3.333333333333 as three")

insert magic

df.printSchema()

>>
root
 |-- txt: string (nullable = false)
 |-- one: float (nullable = false)
 |-- two: float (nullable = false)
 |-- three: float (nullable = false)

Thanks


Solution

  • To resolve the issue, follow the code below. For the sample, I use the four columns above as a DataFrame and register it as a temp view.

    Code:

    from pyspark.sql.functions import col
    from pyspark.sql.types import DecimalType, FloatType
    
    df1 = spark.sql("""
        SELECT 'text' AS txt, 
               CAST(1.1111 AS DECIMAL(5,4)) AS one, 
               CAST(2.22222 AS DECIMAL(6,5)) AS two, 
               CAST(3.333333333333 AS FLOAT) AS three
    """)
    
    df1.createOrReplaceTempView("deci_table")
    
    def convert_decimal_to_float_from_table(table_name):
        df12 = spark.sql(f"SELECT * FROM {table_name}")
    
        # collect the names of all DecimalType columns in the schema
        decimal_columns = [field.name for field in df12.schema.fields if isinstance(field.dataType, DecimalType)]
    
        # cast each decimal column to float; other columns are left untouched
        for col_name in decimal_columns:
            df12 = df12.withColumn(col_name, col(col_name).cast(FloatType()))
    
        return df12
    
    df1.printSchema()
    df12_conv = convert_decimal_to_float_from_table("deci_table")
    
    df12_conv.printSchema()
    
    df_pd = df12_conv.toPandas()
    print(df_pd.dtypes)
    

    Output:

    (Output screenshot: after the conversion, one and two are float in the Spark schema, and df_pd.dtypes reports float32 for one, two, and three, while txt remains object.)
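
    A temp view is not strictly required, by the way. The same cast can be applied directly to a DataFrame in a single select; here is a minimal sketch of that variant (the helper name decimals_to_floats is mine, not from the original code):

    from pyspark.sql.functions import col
    from pyspark.sql.types import DecimalType, FloatType

    def decimals_to_floats(df):
        # cast every DecimalType column to float; pass all other columns through unchanged
        return df.select([
            col(f.name).cast(FloatType()).alias(f.name)
            if isinstance(f.dataType, DecimalType) else col(f.name)
            for f in df.schema.fields
        ])

    df_pd = decimals_to_floats(df1).toPandas()

    NULLs survive the cast: a NULL decimal becomes a NULL float in Spark and a NaN in the resulting pandas column.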