
Plotly Bubble chart from pandas crosstab


How can I plot a bubble chart from a dataframe that has been created from a pandas crosstab of another dataframe?

Imports:

import pandas as pd
import plotly as py
import plotly.graph_objects as go
from plotly.subplots import make_subplots

The crosstab was created using:

df = pd.crosstab(raw_data['Speed'], raw_data['Height'].fillna('n/a'))

The df contains mostly zeros; wherever a non-zero number appears, I want a point whose size is controlled by that value. I want the index values on the x-axis and the column values on the y-axis.

The df looks something like this:

        10    20    30    40    50
1000     0     0     0     0     5
1100     0     0     0     7     0
1200     1     0     3     0     0
1300     0     0     0     0     0
1400     5     0     0     0     0
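As a minimal sketch of the mapping described above (index values to x, column labels to y, cell value to marker size), assuming a small hypothetical raw_data with no missing heights, NumPy's nonzero can pick out just the cells that hold a count:

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for the question's raw_data (no missing heights here)
raw_data = pd.DataFrame({
    "Speed":  [1000, 1000, 1200, 1200, 1200, 1400],
    "Height": [50, 50, 10, 30, 30, 10],
})

# Same kind of call as in the question: rows = Speed, columns = Height, cells = counts
df = pd.crosstab(raw_data["Speed"], raw_data["Height"])

counts = df.to_numpy()
rows, cols = np.nonzero(counts)   # positions of the non-zero cells, row by row

x = df.index[rows]                # Speed values  -> x axis
y = df.columns[cols]              # Height values -> y axis
size = counts[rows, cols]         # cell counts   -> marker size

for point in zip(x, y, size):
    print(point)
```

The three arrays can then be passed as x, y and marker size to a scatter trace.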

I've tried using scatter and Scatter like this:

fig.add_trace(go.Scatter(x=df.index.values, y=df.columns.values, size=df.values,
                         mode='lines'),
              row=1, col=3)

This returned TypeError: 'module' object is not callable.

Any help is really appreciated. Thanks

UPDATE

The answers below are close to what I ended up with; the main difference is that I reference 'Speed' in the melt line:

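df = df.reset_index()                    # the 'Speed' index becomes a column
df = df.melt(id_vars="Speed",            # one row per (Speed, Height) cell
             var_name="Height",          # column labels -> 'Height'
             value_name="Count")         # cell values   -> 'Count'
df = df[df["Count"] != 0]                # keep only the non-zero counts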

scale=1000

fig.add_trace(go.Scatter(x=df["Speed"], y=df["Height"],mode='markers',marker_size=df["Count"]/scale),
              row=1, col=3)

This works; however, my main problem now is that the dataset is huge and Plotly struggles to render it.

Update 2

Using go.Scattergl instead of go.Scatter lets Plotly handle the large dataset very well!


Solution

  • If that's the case, you can use plotly.express. This is very similar to @Erik's answer but shouldn't return errors.

    import pandas as pd
    import plotly.express as px
    from io import StringIO
    
    txt = """
            10    20    30    40    50
    1000     0     0    0      0     5
    1100     0     0    0      7     0
    1200     1     0    3      0     0
    1300     0     0    0      0     0
    1400     5     0    0      0     0
    """
    
    df = pd.read_csv(StringIO(txt), sep=r"\s+")  # same as delim_whitespace=True, deprecated since pandas 2.2
    
    df = df.reset_index()\
           .melt(id_vars="index")\
           .rename(columns={"index":"Speed",
                            "variable":"Height",
                            "value":"Count"})
    
    fig = px.scatter(df, x="Speed", y="Height",size="Count")
    fig.show()
    


    UPDATE: If you get an error, please check your pandas version with pd.__version__ and run the following line by line:

    df = pd.read_csv(StringIO(txt), sep=r"\s+")  # same as delim_whitespace=True, deprecated since pandas 2.2
    
    df = df.reset_index()
    
    df = df.melt(id_vars="index")
    
    df = df.rename(columns={"index":"Speed",
                            "variable":"Height",
                            "value":"Count"})
    

    and report which line it breaks on.
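A minimal sketch of that line-by-line check, assuming the same txt table: each step runs in turn, the pandas version is printed first, and the first step that raises is reported by name. The step labels are illustrative only:

```python
from io import StringIO

import pandas as pd

txt = """
        10    20    30    40    50
1000     0     0     0     0     5
1100     0     0     0     7     0
1200     1     0     3     0     0
1300     0     0     0     0     0
1400     5     0     0     0     0
"""

# Each step takes the previous DataFrame and returns the next one
steps = [
    ("read_csv", lambda _: pd.read_csv(StringIO(txt), sep=r"\s+")),
    ("reset_index", lambda d: d.reset_index()),
    ("melt", lambda d: d.melt(id_vars="index")),
    ("rename", lambda d: d.rename(columns={"index": "Speed",
                                           "variable": "Height",
                                           "value": "Count"})),
]

print("pandas", pd.__version__)
df = None
failed_at = None
for name, step in steps:
    try:
        df = step(df)
    except Exception as err:  # report the first failing step and stop
        failed_at = name
        print(f"failed at {name}: {err!r}")
        break
    print(f"{name}: ok, columns = {list(df.columns)}")
```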