
How do I write this SQL query in Python, preferably with pandas? (Assume that I am working with a DataFrame.)


This is the SQL query that I am trying to convert to pandas:

select 
geo,
region,
sum(case when year(txn_date)>=2020 then revenue else 0 end) as ytd_rev,
sum(case when year(txn_date)=2019 then revenue else 0 end) as py_ytd_rev,
sum(profit) as total_profit
from table
group by 1,2

Assume the input DataFrame has the following columns: geo | region | sub region | txn_date | revenue | profit.

Columns in the output DataFrame: geo | region | ytd_rev | py_ytd_rev | total_profit


Solution

  • I believe you need to create the conditional columns with DataFrame.assign, then sum them per group using GroupBy.agg with named aggregation:

    Creating a sample DataFrame:

    import pandas as pd
    from datetime import datetime as dt
    
    df = pd.DataFrame(columns=['geo','region','sub region', 'txn_date', 'revenue', 'profit'])
    
    df.loc[len(df.index)] = ['G1', 'R1', 'SR1', dt.strptime('23Sep19', '%d%b%y'), 1000, 200]
    df.loc[len(df.index)] = ['G2', 'R1', 'SR1', dt.strptime('10Sep20', '%d%b%y'), 3000, 100]
    df.loc[len(df.index)] = ['G5', 'R2', 'SR1', dt.strptime('11Sep19', '%d%b%y'), 4000, 150]
    df.loc[len(df.index)] = ['G4', 'R2', 'SR2', dt.strptime('15Sep18', '%d%b%y'), 1500, 300]
    df.loc[len(df.index)] = ['G3', 'R1', 'SR1', dt.strptime('30Sep20', '%d%b%y'), 800, -50]
    df.loc[len(df.index)] = ['G6', 'R3', 'SR1', dt.strptime('01Sep19', '%d%b%y'), 3000, 100]
    
    print(df)
    

    The sample dataframe:

      geo region sub region   txn_date revenue profit
    0  G1     R1        SR1 2019-09-23    1000    200
    1  G2     R1        SR1 2020-09-10    3000    100
    2  G5     R2        SR1 2019-09-11    4000    150
    3  G4     R2        SR2 2018-09-15    1500    300
    4  G3     R1        SR1 2020-09-30     800    -50
    5  G6     R3        SR1 2019-09-01    3000    100
    

    The solution to the problem:

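    # Make sure txn_date is a datetime column so that .dt.year is available
    df['txn_date'] = pd.to_datetime(df['txn_date'])

    # SQL "case when <cond> then revenue else 0 end" becomes Series.where(<cond>, 0):
    # revenue is kept where the condition holds and replaced by 0 elsewhere.
    # Each new column is then summed per (geo, region) group.
    df = (df.assign(ytd_rev=df['revenue'].where(df['txn_date'].dt.year >= 2020, 0),
                    py_ytd_rev=df['revenue'].where(df['txn_date'].dt.year == 2019, 0))
            .groupby(['geo', 'region'])
            .agg(ytd_rev=('ytd_rev', 'sum'),
                 py_ytd_rev=('py_ytd_rev', 'sum'),
                 total_profit=('profit', 'sum'))
            .reset_index())

    print(df)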
    

    The final output:

      geo region  ytd_rev  py_ytd_rev  total_profit
    0  G1     R1        0        1000           200
    1  G2     R1     3000           0           100
    2  G3     R1      800           0           -50
    3  G4     R2        0           0           300
    4  G5     R2        0        4000           150
    5  G6     R3        0        3000           100
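
In the sample above every (geo, region) pair occurs only once, so the sums never actually combine rows. The following is a minimal self-contained sketch of the same assign + groupby + named-aggregation pattern on hypothetical data where one group has several rows. The names `data`, `year` and `summary`, and all the values, are assumptions made for illustration and are not part of the original question.

```python
import pandas as pd

# Hypothetical input: group (G1, R1) has three rows in different years.
data = pd.DataFrame({
    "geo": ["G1", "G1", "G1", "G2"],
    "region": ["R1", "R1", "R1", "R1"],
    "txn_date": pd.to_datetime(
        ["2019-03-01", "2020-05-10", "2021-01-15", "2019-07-04"]
    ),
    "revenue": [100, 200, 300, 400],
    "profit": [10, 20, 30, 40],
})

# Extract the year once and reuse it in both conditions.
year = data["txn_date"].dt.year

summary = (
    data.assign(
        # case when year(txn_date) >= 2020 then revenue else 0 end
        ytd_rev=data["revenue"].where(year >= 2020, 0),
        # case when year(txn_date) = 2019 then revenue else 0 end
        py_ytd_rev=data["revenue"].where(year == 2019, 0),
    )
    .groupby(["geo", "region"])
    .agg(
        ytd_rev=("ytd_rev", "sum"),
        py_ytd_rev=("py_ytd_rev", "sum"),
        total_profit=("profit", "sum"),
    )
    .reset_index()
)

print(summary)
```

Here the (G1, R1) group should end up with ytd_rev = 200 + 300, py_ytd_rev = 100 and total_profit = 10 + 20 + 30, which matches what the SQL query would return for the same rows. Building the frame in one constructor call, rather than appending rows one at a time, also lets pandas infer integer dtypes for revenue and profit.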