python, pandas, sankey-diagram

Python: transforming complex data for a Sankey plot


I am trying to produce a Sankey plot of the events that take place in the week before and the week after a patient's index event. Suppose I have the following data frame:

df = 
patient_id  start_date   end_date    Value Index_event_date Value_Index_event
1           28-12-1999  02-01-2000   A     01-01-2000       X
2           28-12-2000  02-12-2001   B     01-01-2001       X                   
3           28-12-2001  02-01-2002   A     01-01-2002       X

I would like to group the rows of this data frame into "codes": the week before the index event is code1, the week of the index event is code2, and the week after the index event is code3.

The resulting data frame would be:

patient_id code1 code2 code3
1          A     X     A
2          B     X     Na
3          A     X     A

In the above example, every patient except patient 2 has observations in both weeks (one before and one after the index event). Patient 2 has an observation only in the week before the index event, which is why code3 (the week after the index event) shows a missing value (Na).


Solution

  • With the dataframe you provided:

    import pandas as pd
    
    df = pd.DataFrame(
        {
            "patient_id": [1, 2, 3],
            "start_date": ["28-12-1999", "28-12-2000", "28-12-2001"],
            "end_date": ["02-01-2000", "02-12-2001", "02-01-2002"],
            "Value": ["A", "B", "A"],
            "Index_event_date": ["01-01-2000", "01-01-2001", "01-01-2002"],
            "Value_Index_event": ["X", "X", "X"],
        }
    )
    

    Here is one way to do it with pandas' to_datetime and DateOffset (assuming that "week" means 7 days before/after the index event date):

    # Setup
    for col in ["start_date", "end_date", "Index_event_date"]:
        df[col] = pd.to_datetime(df[col], format="%d-%m-%Y")
    
    # Add new columns
    df["code1"] = df.apply(
        lambda x: x["Value"]
        if x["start_date"] >= (x["Index_event_date"] - pd.DateOffset(days=7))
        else None,
        axis=1,
    )
    df["code2"] = df["Value_Index_event"]
    df["code3"] = df.apply(
        lambda x: x["Value"]
        if x["end_date"] <= (x["Index_event_date"] + pd.DateOffset(days=7))
        else None,
        axis=1,
    )
    
    # Cleanup
    df = df.drop(
        columns=["start_date", "end_date", "Value", "Index_event_date", "Value_Index_event"]
    )
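
The row-wise apply calls above can also be expressed as vectorized comparisons, which tends to scale better on large frames. The following is a minimal sketch under the same 7-day assumption; the names `events`, `codes`, `week`, `starts_in_week_before` and `ends_in_week_after` are illustrative and not part of the original answer. Note that `Series.where` fills non-matching rows with NaN rather than None.

```python
import pandas as pd

# Same sample data as above, under a separate (illustrative) name
events = pd.DataFrame(
    {
        "patient_id": [1, 2, 3],
        "start_date": ["28-12-1999", "28-12-2000", "28-12-2001"],
        "end_date": ["02-01-2000", "02-12-2001", "02-01-2002"],
        "Value": ["A", "B", "A"],
        "Index_event_date": ["01-01-2000", "01-01-2001", "01-01-2002"],
        "Value_Index_event": ["X", "X", "X"],
    }
)
for col in ["start_date", "end_date", "Index_event_date"]:
    events[col] = pd.to_datetime(events[col], format="%d-%m-%Y")

# Assumption: "one week" means exactly 7 days
week = pd.Timedelta(days=7)

# Boolean masks computed for all rows at once
starts_in_week_before = events["start_date"] >= events["Index_event_date"] - week
ends_in_week_after = events["end_date"] <= events["Index_event_date"] + week

# Series.where keeps the value where the mask is True, NaN otherwise
codes = pd.DataFrame(
    {
        "patient_id": events["patient_id"],
        "code1": events["Value"].where(starts_in_week_before),
        "code2": events["Value_Index_event"],
        "code3": events["Value"].where(ends_in_week_after),
    }
)
```

This builds the code columns in a new frame instead of mutating and then dropping columns, so the original dates stay available for further checks.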
    

    Then:

    print(df)
    # Output
        patient_id  code1   code2   code3
    0            1      A       X       A
    1            2      B       X    None
    2            3      A       X       A
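
Since the end goal is a Sankey plot, the wide code1/code2/code3 table still has to be turned into links (source, target, value). The sketch below shows one possible way to count transitions between consecutive codes with pandas only. The names `codes`, `stages`, `links`, `labels` and `index_of` are assumptions made for illustration, and patients with a missing code are simply dropped from the affected transition, which may or may not be what you want.

```python
import pandas as pd

# Result of the transformation above, re-created here so the sketch is self-contained
codes = pd.DataFrame(
    {
        "patient_id": [1, 2, 3],
        "code1": ["A", "B", "A"],
        "code2": ["X", "X", "X"],
        "code3": ["A", None, "A"],
    }
)

stages = ["code1", "code2", "code3"]

# One source/target pair per patient and per consecutive pair of stages
pieces = []
for left, right in zip(stages[:-1], stages[1:]):
    pair = codes[[left, right]].dropna()  # skip patients missing either code
    pieces.append(
        pd.DataFrame(
            {
                # Prefix with the stage name so "A" before and "A" after
                # the index event become distinct nodes
                "source": left + ": " + pair[left],
                "target": right + ": " + pair[right],
            }
        )
    )

# Count how many patients follow each source -> target transition
links = (
    pd.concat(pieces, ignore_index=True)
    .groupby(["source", "target"])
    .size()
    .reset_index(name="value")
)

# Most Sankey plotting APIs expect integer node indices plus a label list
labels = sorted(set(links["source"]) | set(links["target"]))
index_of = {label: i for i, label in enumerate(labels)}
links["source_idx"] = links["source"].map(index_of)
links["target_idx"] = links["target"].map(index_of)

print(links)
```

If you plot with Plotly, for example, `labels` could be passed as the node labels of `plotly.graph_objects.Sankey` and the `source_idx`, `target_idx` and `value` columns as its link source, target and value lists.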