Hi, I have a simple 3D scatter plot of a DataFrame bm, with its columns and index as the x and y axes. When I plot it, I want to add a colour map; that is also simple, and I've done it below.
However, my data bm contains some zeros that I do not want to plot. This is also easy: I set them to NaN. However, this causes a problem with the colour matrix, because scatter does not accept it. I've tried passing the colour matrix both with and without NaNs, and both fail with an error.
The code below is fully functional: if you remove the line bm = bm.replace({0: np.nan}), it plots.
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
N = 100
bm = pd.DataFrame(
index=pd.bdate_range(start='2012-01-01', periods=N, freq='B'),
data={x: np.random.randn(N) for x in range(1, 11)}
)
# Simulate some zeros
bm = pd.DataFrame(index=bm.index, columns=bm.columns, data=np.where(np.abs(bm.values) < 0.02, 0, bm.values))
# Set zeros to NaN so that I don't plot them
bm = bm.replace({0: np.nan})
z = bm.values
x = bm.columns.tolist()
y = bm.reset_index().index.tolist()
x, y = np.meshgrid(x, y)
# Set up plot
fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(projection='3d')
# plotting
ax.scatter(x, y, z, marker='.', c=bm.values, cmap='Reds')  # THIS FAILS
ax.set_xticks(bm.columns)
yticks = np.arange(0, N, 20)  # label every 20th business day
ax.set_yticks(yticks)
ax.set_yticklabels(bm.index[yticks].strftime('%Y-%m-%d'));
Any help is welcome.
I'm not 100% sure why this is failing, but I guess it has something to do with c being wrongly identified as an RGB/RGBA array because it is 2D.
From the docs:
c : color, sequence, or sequence of colors, optional
    The marker color. Possible values:
- A single color format string.
- A sequence of colors of length n.
- A sequence of n numbers to be mapped to colors using cmap and norm.
- A 2D array in which the rows are RGB or RGBA.
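To test the guess above, here is a small headless reproduction. It assumes a recent Matplotlib (3.2 or later, so projection="3d" works without importing mpl_toolkits explicitly) and uses synthetic data in place of bm. The idea it checks: the 3D scatter appears to flatten x, y and z and drop any point with a NaN coordinate, while a 2D c (whose len is 100, not 1000) is passed through unfiltered. Without NaNs the sizes match and c is colour-mapped; with even one NaN they no longer match, so matplotlib falls back to reading c as RGB/RGBA rows and raises a ValueError.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

# Synthetic stand-in for bm: 100 rows x 10 columns, as in the question.
x, y = np.meshgrid(np.arange(1, 11), np.arange(100))
z = np.linspace(-1.0, 1.0, 1000).reshape(100, 10)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# No NaNs: c.size (1000) equals the number of points, so c is colour-mapped.
sc_ok = ax.scatter(x, y, z, c=z, cmap="Reds")
n_mapped = len(sc_ok.get_array())

# One NaN: that point is dropped from x, y, z but the 2D c is not filtered,
# so its size (1000) no longer matches the 999 remaining points.
z_nan = z.copy()
z_nan[0, 0] = np.nan
try:
    ax.scatter(x, y, z_nan, c=z_nan, cmap="Reds")
    raised = False
except ValueError as err:
    raised = True
    print("ValueError:", err)

plt.close(fig)
```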
If you convert your data and coordinates to 1D before plotting, scatter seems to handle the NaNs just fine:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
N = 100
bm = pd.DataFrame(
index=pd.bdate_range(start='2012-01-01', periods=N, freq='B'),
data={x: np.random.randn(N) for x in range(1, 11)}
)
# Simulate some zeros
bm = pd.DataFrame(index=bm.index, columns=bm.columns, data=np.where(np.abs(bm.values) < 0.02, 0, bm.values))
# Set zeros to NaN so that I don't plot them
bm = bm.replace({0: np.nan})
# unstack to a 1D Series indexed by (column, row position)
flat_bm = bm.reset_index(drop=True).unstack()
x = flat_bm.index.get_level_values(0)
y = flat_bm.index.get_level_values(1)
z = flat_bm.values
# Set up plot
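fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(projection='3d')
# plotting: x, y, z and c are all 1D with the same length, so scatter can
# drop the NaN points from the coordinates and the colours together
ax.scatter(x, y, z, marker='.', c=flat_bm.values, cmap='Reds')
ax.set_xticks(bm.columns)
yticks = np.arange(0, N, 20)  # label every 20th business day
ax.set_yticks(yticks)
ax.set_yticklabels(bm.index[yticks].strftime('%Y-%m-%d'))
plt.show()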
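An alternative sketch, not the method used above: skip the pandas reshaping, build one boolean mask of the finite values, and index x, y and z with it. Boolean-indexing a 2D array with a 2D mask returns a 1D array, so every array passed to scatter is 1D and the same length, and no NaNs reach it at all. The random data and the NaN positions below are made up for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in for bm.values: 100 dates x 10 columns, 3 NaNs.
rng = np.random.default_rng(0)
z = rng.standard_normal((100, 10))
z[[5, 50, 99], [1, 4, 8]] = np.nan

# 2D coordinate grids built as in the question (columns 1..10, rows 0..99).
x, y = np.meshgrid(np.arange(1, 11), np.arange(z.shape[0]))

# One mask applied to every array keeps x, y, z and c aligned and 1D.
mask = np.isfinite(z)
xs, ys, zs = x[mask], y[mask], z[mask]

fig = plt.figure(figsize=(15, 10))
ax = fig.add_subplot(projection="3d")
sc = ax.scatter(xs, ys, zs, c=zs, cmap="Reds", marker=".")
fig.colorbar(sc, ax=ax, shrink=0.6)  # show the colour scale for the map
n_plotted = len(sc.get_array())
plt.close(fig)
```

Because the filtering happens before the call, this does not depend on how scatter treats NaNs internally.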