
Cohort Chart using plotly library


I have a dataset (CSV file) and I want to build a cohort analysis chart using the plotly library. Is it possible? I couldn't find any tutorials on it.


Solution

    • as per my questions in the comments, a cohort chart is not a chart type but an approach to analysis
    • for the purpose of this analysis I have reduced the granularity of dates by considering only the month start
    • the first part of a cohort analysis is placing your data into cohorts. The most common approach appears to be the month in which a client was first observed. I have used Date, which is derived from InvoiceDate
    • the next part is to look at the activity of each cohort by month after they became a client. I have used pandas date capabilities again, sticking to month starts
    • now we can calculate the total amount spent by cohort and by month after they became a client
    • rebase this as a percentage of each cohort's month-0 spend, as this seems to be the way cohort analysis always works
    • now the simple bit: generate the plotly heatmap
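
To make the steps above concrete, here is a small sketch on a hypothetical toy set of invoices (two customers, made-up values), so the intermediate numbers can be checked by hand. It uses `groupby(...).transform("min")` to assign cohorts, which gives the same cohort column as a merge when every row has a CustomerID. One invoice deliberately falls on the 1st of a month: `to_period("M").dt.to_timestamp()` keeps it in its own month.

```python
import pandas as pd

# Hypothetical toy invoices, same column names as the Kaggle e-commerce data
toy = pd.DataFrame(
    {
        "CustomerID": [1, 1, 1, 2, 2],
        "InvoiceDate": [
            "2011-01-01 09:00", "2011-02-15 10:30", "2011-03-03 12:00",
            "2011-02-01 08:00", "2011-03-20 17:45",
        ],
        "Quantity": [2, 1, 3, 5, 2],
        "UnitPrice": [5.0, 5.0, 2.5, 2.0, 1.0],
    }
)

# 1. reduce each invoice date to the start of its month
toy["Date"] = pd.to_datetime(toy["InvoiceDate"]).dt.to_period("M").dt.to_timestamp()
# 2. cohort = first month in which the customer was seen
toy["Cohort"] = toy.groupby("CustomerID")["Date"].transform("min")
# 3. whole months elapsed since the customer joined
toy["Month"] = (toy["Date"].dt.year - toy["Cohort"].dt.year) * 12 + (
    toy["Date"].dt.month - toy["Cohort"].dt.month
)
# 4. amount spent per cohort and per month since joining
spend = (
    toy.assign(Revenue=toy["Quantity"] * toy["UnitPrice"])
    .groupby(["Cohort", "Month"])["Revenue"]
    .sum()
    .unstack("Month")
)
# 5. rebase each cohort row against its own month-0 spend
rebased = spend.div(spend[0], axis=0)
print(rebased)
```

The January cohort spends 10, 5 and 7.5 in months 0 to 2, so its rebased row is 1.0, 0.5, 0.75; the February cohort has no month 2 yet, so that cell is NaN (an empty cell in the heatmap).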

    data prep and plotting (assumes df has already been loaded as in data sourcing below)

    import pandas as pd
    import plotly.express as px
    
    # just the month start, time doesn't matter
    # (to_period keeps invoices dated on the 1st in their own month)
    df["Date"] = pd.to_datetime(df["InvoiceDate"]).dt.to_period("M").dt.to_timestamp()
    # work out when each customer was first a customer to define their cohort
    df2 = df.merge(
        df.groupby("CustomerID", as_index=False).agg(Cohort=("Date", "min")),
        on="CustomerID",
    )
    # whole months between cohort start and invoice month
    df2["Month"] = (df2["Date"].dt.year - df2["Cohort"].dt.year) * 12 + (
        df2["Date"].dt.month - df2["Cohort"].dt.month
    )
    
    # total amount spent per cohort per month since joining
    df_cohort = (
        df2.assign(Revenue=df2["Quantity"] * df2["UnitPrice"])
        .groupby(["Cohort", "Month"])["Revenue"]
        .sum()
        .unstack("Month")
    )
    
    # rebase each cohort as a percentage of its month-0 spend (month 0 becomes 1)
    df_cohort = df_cohort.div(df_cohort[0], axis=0)
    
    # now the easy bit - generate a figure
    fig = px.imshow(
        df_cohort, text_auto=".2%", color_continuous_scale="blues", range_color=[0, 1]
    ).update_xaxes(side="top", dtick=1).update_yaxes(dtick="M1")
    fig.show()
    

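
The heatmap above rebases revenue. A common variant of cohort analysis instead tracks what share of each cohort's customers is still active N months after joining. A minimal sketch, assuming a frame with the Cohort, Month and CustomerID columns built in the data prep (the function name `customer_retention` and the toy rows are hypothetical):

```python
import pandas as pd


def customer_retention(df2: pd.DataFrame) -> pd.DataFrame:
    """Share of each cohort's customers active N months after joining.

    Hypothetical helper; expects Cohort, Month and CustomerID columns.
    """
    counts = (
        df2.groupby(["Cohort", "Month"])["CustomerID"]
        .nunique()  # distinct active customers, not revenue
        .unstack("Month")
    )
    # month 0 holds every customer in the cohort, so it is the cohort size
    return counts.div(counts[0], axis=0)


# hypothetical toy data: the Jan-2011 cohort has 2 customers, 1 returns in month 1
toy = pd.DataFrame(
    {
        "CustomerID": [1, 2, 1],
        "Cohort": pd.to_datetime(["2011-01-01"] * 3),
        "Month": [0, 0, 1],
    }
)
retention = customer_retention(toy)
print(retention)
```

The result has the same shape as df_cohort (cohorts as rows, months since joining as columns), so it can be passed to the same `px.imshow` call.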

    data sourcing

    import sys
    import urllib.parse
    from pathlib import Path
    from zipfile import ZipFile
    
    import kaggle.cli  # needs Kaggle API credentials (~/.kaggle/kaggle.json)
    import pandas as pd
    
    # download data set with the kaggle CLI, which expects "owner/dataset"
    url = "https://www.kaggle.com/datasets/carrie1/ecommerce-data"
    ds = "/".join(urllib.parse.urlparse(url).path.split("/")[2:])  # carrie1/ecommerce-data
    sys.argv = [sys.argv[0]] + f"datasets download {ds}".split(" ")
    kaggle.cli.main()
    
    # read every file in the downloaded zip; the CSV is not valid UTF-8
    zfile = ZipFile(next(Path.cwd().glob(f"{ds.split('/')[-1]}*.zip")))
    dfs = {
        f.filename: pd.read_csv(zfile.open(f), encoding="unicode_escape")
        for f in zfile.infolist()
    }
    df = dfs["data.csv"]
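
The download needs a Kaggle account and API token. To try the data prep without one, a hypothetical stand-in df with the four columns the prep code reads (InvoiceDate, CustomerID, Quantity, UnitPrice) can be generated. The values are random, so the resulting heatmap says nothing about the real dataset:

```python
import numpy as np
import pandas as pd

# Hypothetical random invoices for 2011; seeded so reruns give the same frame
rng = np.random.default_rng(seed=0)
n_rows = 1_000
df = pd.DataFrame(
    {
        # day offsets 0..364 from 1 Jan 2011, so every date falls within 2011
        "InvoiceDate": pd.Timestamp("2011-01-01")
        + pd.to_timedelta(rng.integers(0, 365, n_rows), unit="D"),
        "CustomerID": rng.integers(10_000, 10_100, n_rows),  # 100 made-up customers
        "Quantity": rng.integers(1, 20, n_rows),
        "UnitPrice": rng.uniform(0.5, 10.0, n_rows).round(2),
    }
)
print(df.head())
```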