python apache-spark date pyspark multiple-columns

How to define a generic function to derive multiple new columns from a given column name?


I have been trying to derive two columns from a given column through a Python function. Here is the code snippet:

from pyspark.sql.functions import substring

def deriveCol(source_col_name, col1, col2):
    df.select(source_col_name, substring(source_col_name, 1, 4).alias(col1), substring(source_col_name, 5, 2).alias(col2))
    for i in col2:
        if i >= "01" and i <= "03":
            print("First quarter")
        elif i >= "04" and i <= "06":
            print("second quarter")
        elif i >= "06" and i <= "09":
            print("Third quarter")
        else:
            print("Fourth quarter")
    return df.select([col1, col2]).show(10, truncate=True)
 
t = deriveCol("Report", "year", "month")

Here is the output:

Fourth quarter
Fourth quarter
+----+-----+
|year|month|
+----+-----+
|2022|   01|
|2022|   01|
|2022|   01|
|2022|   01|
|2022|   01|
|2022|   01|
|2022|   01|
|2022|   01|
|2022|   01|
|2022|   01|
+----+-----+
only showing top 10 rows

"Report" = the source column from which the data has to be derived.
"year"/"month" = the derived columns.

I have a column coming from a database table, and I have to split it into 3 parts: "year", "month", "quarter".

The "year" and "month" were easy, but deriving the "quarter" is not working. When I run the code, it goes straight to the else branch and prints "Fourth quarter", even though the "month" column I derived above with the substring() function contains values from all months of the year.

Note: I need to save the output of the "quarter" in a separate column.


Solution

  • As the first step, I created a true date column from the string (the 'yyyyMMdd' format used below assumes values such as '20220101'). With a date column, we can apply the functions year, month and quarter.

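    from pyspark.sql import functions as F

    def derive_col(source_col_name, col1, col2, col3):
        # Parse the string column into a real date (assumes values like '20220101').
        date_col = F.to_date(source_col_name, 'yyyyMMdd')
        # Relies on a DataFrame named `df` already existing in the enclosing scope.
        # Returning the result directly avoids rebinding `df` as a local name,
        # which would raise UnboundLocalError.
        return df.select(
            *[c for c in df.columns if c != source_col_name],
            F.year(date_col).alias(col1),
            F.month(date_col).alias(col2),
            F.quarter(date_col).alias(col3)
        )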
    

    Full test, passing the DataFrame as one more parameter, df, so that the function does not depend on a variable defined outside it:

    from pyspark.sql import functions as F
    
    df = spark.createDataFrame([('20220101',), ('20220731',)], ['Report'])
    
    def derive_col(df, source_col_name, col1, col2, col3):
        date_col = F.to_date(source_col_name, 'yyyyMMdd')
        df = df.select(
            *[c for c in df.columns if c != source_col_name],
            F.year(date_col).alias(col1),
            F.month(date_col).alias(col2),
            F.quarter(date_col).alias(col3)
        )
        return df
    
    df = derive_col(df, 'Report', 'year', 'month', 'quarter')
    
    df.show()
    # +----+-----+-------+
    # |year|month|quarter|
    # +----+-----+-------+
    # |2022|    1|      1|
    # |2022|    7|      3|
    # +----+-----+-------+
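
A note on why the loop in the question always fell through to else, as a minimal pure-Python sketch (no Spark needed). Inside the function, col2 is just the column name that was passed in (the string "month" in the call shown), so `for i in col2` walks over the characters of that name rather than over the rows of the DataFrame. Letters compare greater than "0" as strings, so no branch with an upper bound such as "03" can match. The helper names label_for and quarter_of are hypothetical, introduced only for illustration. quarter_of shows the plain arithmetic that maps months 1 to 12 onto quarters 1 to 4, which agrees with the F.quarter results shown above for months 1 and 7.

```python
def label_for(value):
    """Reproduce the question's if/elif chain for a single string value."""
    if "01" <= value <= "03":
        return "First quarter"
    elif "04" <= value <= "06":
        return "second quarter"
    elif "06" <= value <= "09":
        return "Third quarter"
    return "Fourth quarter"


# Iterating over the column *name* yields its characters: 'm', 'o', 'n', 't', 'h'.
# Each letter sorts after "0", so every comparison against "03", "06" or "09" fails.
labels = [label_for(ch) for ch in "month"]
print(labels)


def quarter_of(month):
    """Map a month number or zero-padded month string ("01".."12") to 1..4."""
    return (int(month) - 1) // 3 + 1


# Actual month values behave as expected once they are compared as numbers.
print([quarter_of(m) for m in ("01", "04", "07", "12")])
```

In Spark itself, the row-wise work is better kept in column expressions such as F.quarter, as in the function above, rather than in a Python loop over a column name.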