
Trim Grouped Column/Series Sequences in Pandas by NaN Occurrence


I have a data frame as follows:

user_id metric_date metric_val1 is_churn
3 2021-01 NaN True
3 2021-02 NaN True
3 2021-03 0.4 False
3 2021-04 0.5 False
3 2021-05 NaN True
4 2021-01 0.1 False
4 2021-02 0.3 False
4 2021-03 0.2 False
4 2021-04 NaN True
4 2021-05 NaN True

Suppose there are other metric columns, but metric_val1 is the main reference. How can I group by user_id, trim all rows whose metric_val1 is NaN before the first valid value, and keep only the first NaN row after the last valid value? The output should look like this (assume there are no gaps in the valid values):

user_id metric_date metric_val1 is_churn
3 2021-03 0.4 False
3 2021-04 0.5 False
3 2021-05 NaN True
4 2021-01 0.1 False
4 2021-02 0.3 False
4 2021-03 0.2 False
4 2021-04 NaN True

Can someone help me with an efficient way to do that in pandas?
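To make the question reproducible, the sample frame can be rebuilt as below. This is a minimal sketch under two assumptions not stated in the post: metric_date is stored as 'YYYY-MM' strings (which sort correctly as text), and missing values are np.nan rather than None.

```python
import numpy as np
import pandas as pd

# Rebuild the question's sample frame
# (assumptions: metric_date as 'YYYY-MM' strings, missing values as np.nan)
df = pd.DataFrame({
    'user_id': [3] * 5 + [4] * 5,
    'metric_date': ['2021-01', '2021-02', '2021-03', '2021-04', '2021-05'] * 2,
    'metric_val1': [np.nan, np.nan, 0.4, 0.5, np.nan, 0.1, 0.3, 0.2, np.nan, np.nan],
    'is_churn': [True, True, False, False, True, False, False, False, True, True],
})
print(df)
```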


Solution

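  • Use boolean selection: keep every non-NaN value, plus any NaN value that immediately follows a non-NaN value within the same user_id group, and use the result as a mask. This assumes the rows are already sorted by metric_date within each user. Code below: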

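    # Previous metric_val1 within the same user (NaN on each user's first row)
    prev = df.groupby('user_id')['metric_val1'].shift(1)
    # Keep valid values, and NaNs whose previous row (same user) is valid
    df[df['metric_val1'].notna() | (df['metric_val1'].isna() & prev.notna())]

        user_id metric_date  metric_val1  is_churn
    2        3     2021-03          0.4     False
    3        3     2021-04          0.5     False
    4        3     2021-05          NaN      True
    5        4     2021-01          0.1     False
    6        4     2021-02          0.3     False
    7        4     2021-03          0.2     False
    8        4     2021-04          NaN      True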
    

    If you have a large DataFrame and are worried about memory and speed, you could try PySpark, which scales out across a cluster. Start a SparkSession, load the data as a Spark DataFrame, and apply the same logic with a window function:

    import pyspark.sql.functions as F
    from pyspark.sql import Window

    # Order each user's rows by date
    k = Window.partitionBy('user_id').orderBy('metric_date')

    def is_valid(c):
        # A value is valid if it is neither null nor NaN
        return c.isNotNull() & ~F.isnan(c)

    (
      df.withColumn('t', F.lag('metric_val1').over(k))  # t = previous row's metric_val1 for the same user (null on the first row)
        .filter(is_valid(F.col('metric_val1')) | is_valid(F.col('t')))  # keep valid rows and the first NaN after a valid row
        .drop('t')  # drop the temporary column
    ).show()

    +-------+-----------+-----------+--------+
    |user_id|metric_date|metric_val1|is_churn|
    +-------+-----------+-----------+--------+
    |      3|    2021-03|        0.4|   false|
    |      3|    2021-04|        0.5|   false|
    |      3|    2021-05|        NaN|    true|
    |      4|    2021-01|        0.1|   false|
    |      4|    2021-02|        0.3|   false|
    |      4|    2021-03|        0.2|   false|
    |      4|    2021-04|        NaN|    true|
    +-------+-----------+-----------+--------+
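The masks above rely on the question's no-gap assumption: if a user's valid values contain a gap, the NaN directly after a valid value is kept but any further interior NaNs are dropped. If you would rather keep every row from a user's first valid metric_val1 through the first NaN after the last valid value, interior gaps included, the pandas sketch below uses forward and reverse cumulative maxima per group. It is a variant, not the method shown above; the helper name trim_user is hypothetical, and metric_date is assumed to be stored as sortable 'YYYY-MM' strings.

```python
import numpy as np
import pandas as pd

# Sample data from the question (assumption: metric_date stored as 'YYYY-MM' strings)
df = pd.DataFrame({
    'user_id': [3] * 5 + [4] * 5,
    'metric_date': ['2021-01', '2021-02', '2021-03', '2021-04', '2021-05'] * 2,
    'metric_val1': [np.nan, np.nan, 0.4, 0.5, np.nan, 0.1, 0.3, 0.2, np.nan, np.nan],
    'is_churn': [True, True, False, False, True, False, False, False, True, True],
})


def trim_user(valid: pd.Series) -> pd.Series:
    """Hypothetical helper: row mask for one user, given metric_val1.notna() in date order."""
    # True from the first valid value onwards (drops leading NaNs)
    after_first = valid.cummax()
    # True up to and including the last valid value (reverse cumulative max)
    up_to_last = valid.iloc[::-1].cummax().iloc[::-1]
    # Extend by one row so the first NaN after the last valid value is kept
    keep_tail = up_to_last | up_to_last.shift(1, fill_value=False)
    return after_first & keep_tail


df = df.sort_values(['user_id', 'metric_date'])
mask = df['metric_val1'].notna().groupby(df['user_id']).transform(trim_user)
result = df[mask]
print(result)
```

On the sample data this keeps the same seven rows as the expected output. The difference only shows up with gaps: for a user whose values are 0.1, NaN, NaN, 0.3, this variant keeps all four rows, while the shift-based mask drops the second NaN.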