matplotlib scatter plot with NaNs and a cmap colour matrix


Hi, I have a simple 3D scatter plot of a dataframe bm, with its columns and index as the x and y axes. I want to colour the points with a colour map, which is also simple, and I've done it below.

However, bm contains some zeros that I do not want to plot. That is easy too: I set them to NaN. The problem is the colour matrix, which scatter does not like. I've tried passing the colour matrix both with and without NaNs, and both fail with an error.

The code below is fully functional: if you remove the line bm = bm.replace({0: np.nan}), it will plot.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

N = 100

bm = pd.DataFrame(
    index=pd.bdate_range(start='2012-01-01', periods=N, freq='B'), 
    data={x: np.random.randn(N) for x in range(1, 11)}
)

# Simulate some zeros
bm = pd.DataFrame(index=bm.index, columns=bm.columns, data=np.where(np.abs(bm.values) < 0.02, 0, bm.values))

# Set zeros to Nan so that I don't plot them
bm = bm.replace({0: np.nan})

z = bm.values
x = bm.columns.tolist()
y = bm.reset_index().index.tolist()
x, y = np.meshgrid(x, y)

# Set up plot
fig = plt.figure(figsize = (15,10))
ax = plt.axes(projection ='3d')

# plotting
ax.scatter(x, y, z, '.', c=bm.values, cmap='Reds')  # THIS FAILS

ax.xaxis.set_ticklabels(bm.columns);
ax.yaxis.set_ticklabels(bm.index.strftime('%Y-%m-%d'));

Any help would be welcome.


Solution

  • I'm not 100% sure why this fails, but I suspect c is being misinterpreted as an RGB/RGBA array because it is 2D.

    From the docs:

    c: color, sequence, or sequence of colors, optional
    The marker color. Possible values:

    • A single color format string.
    • A sequence of colors of length n.
    • A sequence of n numbers to be mapped to colors using cmap and norm.
    • A 2D array in which the rows are RGB or RGBA.
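To make the last two forms in that list concrete, here is a minimal 2D sketch that contrasts them. The variable names and the small `n` are made up for illustration, and the `Agg` backend is selected only so that the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

n = 5  # illustrative number of points
x = np.arange(n)
y = np.arange(n)

fig, ax = plt.subplots()

# A sequence of n numbers: each value is mapped to a colour through cmap
values = np.linspace(0.0, 1.0, n)
mapped = ax.scatter(x, y, c=values, cmap='Reds')

# A 2D array with one RGBA row per point: the colours are used as given
rgba = np.tile([0.8, 0.1, 0.1, 1.0], (n, 1))
explicit = ax.scatter(x, y + 1, c=rgba)

print(mapped.get_array().shape)        # one scalar per point
print(explicit.get_facecolor().shape)  # one RGBA row per point
```

A 2D matrix of data values, such as bm.values, matches neither form: it is not a flat sequence of n numbers, and its rows are not RGB or RGBA colours.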

    If you flatten the data and the coordinates to 1D before plotting, scatter seems to handle the NaNs just fine:

    import pandas as pd
    import matplotlib.pyplot as plt
    import numpy as np
    
    
    N = 100
    
    bm = pd.DataFrame(
        index=pd.bdate_range(start='2012-01-01', periods=N, freq='B'), 
        data={x: np.random.randn(N) for x in range(1, 11)}
    )
    
    # Simulate some zeros
    bm = pd.DataFrame(index=bm.index, columns=bm.columns, data=np.where(np.abs(bm.values) < 0.02, 0, bm.values))
    
    # Set zeros to Nan so that I don't plot them
    bm = bm.replace({0: np.nan})
    
    # unstack dataframe
    flat_bm = bm.reset_index(drop=True).unstack()
    x = flat_bm.index.get_level_values(0)
    y = flat_bm.index.get_level_values(1)
    z = flat_bm.values
    
    
    # Set up plot
    fig = plt.figure(figsize = (15,10))
    ax = plt.axes(projection ='3d')
    
    # plotting
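    # pass the marker by keyword; the fourth positional argument of the 3D scatter is zdir, not the marker
    ax.scatter(x, y, z, c=flat_bm.values, cmap='Reds', marker='.')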
    
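    # Place the ticks explicitly so that the labels line up with the data
    ax.set_xticks(list(bm.columns))
    yticks = np.arange(0, N, 20)  # label every 20th row to keep the axis readable
    ax.set_yticks(yticks)
    ax.set_yticklabels(bm.index[yticks].strftime('%Y-%m-%d'))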
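
The same flattening can also be done with plain NumPy, keeping the question's meshgrid and dropping the NaN entries explicitly so that the coordinates and the colour values always have the same length. This is only a sketch: `flatten_grid` is a hypothetical helper name, not a matplotlib or pandas function, and the small `grid` array stands in for bm.values.

```python
import numpy as np

def flatten_grid(values, col_labels):
    """Return 1D x, y, z arrays for a 2D grid, dropping NaN entries.

    Hypothetical helper for illustration; not part of matplotlib or pandas.
    """
    values = np.asarray(values, dtype=float)
    n_rows, n_cols = values.shape
    # x follows the column labels, y is the row position, as in the question
    x, y = np.meshgrid(np.asarray(col_labels), np.arange(n_rows))
    x, y, z = x.ravel(), y.ravel(), values.ravel()
    keep = ~np.isnan(z)  # explicit mask: x, y and z are filtered together
    return x[keep], y[keep], z[keep]

# Tiny stand-in for bm.values, with NaN where the zeros used to be
grid = np.array([[0.5, np.nan, -1.2],
                 [np.nan, 0.3, 2.0]])
x, y, z = flatten_grid(grid, col_labels=[1, 2, 3])

# With a 3D axes `ax` set up as in the answer above:
# ax.scatter(x, y, z, c=z, cmap='Reds', marker='.')
```

With the real data the call would be flatten_grid(bm.values, bm.columns). Filtering up front is a design choice rather than a requirement: it avoids relying on how scatter treats NaNs internally.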