Tags: python, csv, matplotlib, matplotlib-3d

Assign 3 colors to a 3D plot based on Z value


I am trying to color my 3D Python plot with 3 different colors based on the Z value read from a CSV file. I want points at one specific Z value to be one color, points below it a second color, and points above it a third color. I can color the plot with one color above the threshold and one below it, but I can't figure out how to get 3 colors.

I tried splitting the Z values into 3 subsets, but when I plotted them the plot was empty. I also tried an if statement that assigns a color based on Z, but that did not work either. This is the code that works for 2 colors:

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import pandas as pd

headers = ['name','ra','x rads','x par','dec','y rads','Parallax','Parallax Error','central distance','Z Max','Z Min']
mergeddata = pd.read_csv(r'C:\Users\GregL\Downloads\mergedata - no neg parallax #s (2).csv')

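# Pull the plotting columns out of the DataFrame
# (assigning new attributes like mergeddata.z does not create columns and triggers a pandas warning)
x = mergeddata['x par']
y = mergeddata['y rads']
z = mergeddata['central distance']

# Boolean mask: True at or below the threshold, False above it (two colors)
colors = z <= 1956.783590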

fig = plt.figure(figsize=(10,10))

ax = fig.add_subplot(111, projection='3d')

surf = ax.scatter3D(x, y, z, c=colors, cmap='coolwarm', s=.5, marker='^')

ax.set_title('3D Data Distance Plot')

ax.set_zlim(-100,10000)
ax.set_xlim(-50,50)
ax.set_ylim(-50,50)

ax.set_xlabel('RA')
ax.set_ylabel('DEC')
ax.set_zlabel('CENTRAL DISTANCE')

plt.show()

This gives me the following plot: [image: 3D plot]


Solution

  • As Claudio mentioned in the comments, you can compute a color value for each point and then assign a suitable colormap. Here, I create a discrete colormap based on Matplotlib's tab10:

    import numpy as np
    import matplotlib.pyplot as plt
    import matplotlib.cm as cm
    import matplotlib.colors as col
    
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    
    # Sample data: a spiral rising from z=0 to z=1
    z = np.linspace(0, 1, 100)
    x = z * np.sin(25 * z)
    y = z * np.cos(25 * z)
    
    # Color value per point: integer bin of width 0.4 (0, 1 or 2 for z in [0, 1])
    c = [int(zv / 0.4) for zv in z]
    # Discrete colormap with one tab10 color per distinct bin (3 colors here)
    cmap = col.ListedColormap(cm.tab10.colors[:len(np.unique(c))])
    ax.scatter(x, y, z, c=c, cmap=cmap)
    plt.show()
    

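If you would rather pick the three colors yourself instead of taking the first entries of tab10, one option is to bin z with `np.digitize` and pair a `ListedColormap` with a `BoundaryNorm`, so each integer class always maps to the same color, even if one class happens to be missing from the data. This is a sketch on the same synthetic spiral; the bin edges (0.3 and 0.6) and the color names are arbitrary choices for illustration. For the CSV in the question, the edges would be placed around the 1956.783590 threshold instead.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Same synthetic spiral as above
z = np.linspace(0, 1, 100)
x = z * np.sin(25 * z)
y = z * np.cos(25 * z)

# Illustrative bin edges: class 0 is z < 0.3, class 1 is 0.3 <= z < 0.6, class 2 is z >= 0.6
edges = [0.3, 0.6]
classes = np.digitize(z, edges)

# One explicit color per class; the boundaries put each integer class in its own bin
cmap = ListedColormap(["tab:blue", "tab:green", "tab:red"])
norm = BoundaryNorm([-0.5, 0.5, 1.5, 2.5], cmap.N)

fig = plt.figure()
ax = fig.add_subplot(projection="3d")
ax.scatter(x, y, z, c=classes, cmap=cmap, norm=norm)
plt.show()
```

Because the norm fixes the class-to-color mapping, the colors do not shift when the data only covers part of the z range.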

    Alternatively, you can call ax.scatter several times, once per subset. The advantage of this approach is that you can set custom labels or rendering properties for each subset:

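    # Reuses x, y and z from the previous snippet
    fig = plt.figure()
    ax = plt.axes(projection='3d')
    
    # Boolean masks selecting the three z ranges
    i1 = z < 0.3
    i2 = (z >= 0.3) & (z < 0.6)
    i3 = z >= 0.6
    
    # One scatter call per subset, each with its own legend label
    ax.scatter(x[i1], y[i1], z[i1], label="a")
    ax.scatter(x[i2], y[i2], z[i2], label="b")
    ax.scatter(x[i3], y[i3], z[i3], label="c")
    ax.legend()
    plt.show()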
    

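Applied to the columns from the question, the same idea can also be written without a colormap: `np.select` builds an array of color names (a vectorized version of the if statement the question mentions), and `scatter3D` accepts that array directly as `c`. Because an exact floating-point match on 1956.783590 will rarely occur, this sketch treats points within a tolerance `tol` of the threshold as the "specific value" group. `tol`, the color names, and the random stand-in data are assumptions; replace `df` with the DataFrame returned by `pd.read_csv`.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Stand-in for the CSV in the question: same column names, made-up values
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "x par": rng.uniform(-50, 50, 500),
    "y rads": rng.uniform(-50, 50, 500),
    "central distance": rng.uniform(0, 10000, 500),
})

threshold = 1956.783590  # value from the question
tol = 50.0               # hypothetical half-width of the "specific value" band

z = df["central distance"].to_numpy()
near = np.abs(z - threshold) <= tol

# First matching condition wins: near -> gold, otherwise below -> blue, else red
colors = np.select([near, z < threshold], ["gold", "tab:blue"], default="tab:red")

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(projection="3d")
ax.scatter3D(df["x par"], df["y rads"], z, c=colors, s=5, marker="^")
ax.set_xlabel("RA")
ax.set_ylabel("DEC")
ax.set_zlabel("CENTRAL DISTANCE")
plt.show()
```

Since `np.select` checks the conditions in order, a point inside the tolerance band gets the first color even though it is also below or above the threshold.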