Tags: python, matplotlib, 3d, mplot3d

How to colour data points on a 3D scatterplot in matplotlib


Edit: I managed to figure out what was going on. Scatter has a 'linewidths' parameter (short form lw=n) that sets the thickness of the outline drawn around each point. Because my points were tiny (s=1), that outline was thick enough to cover the point's fill colour, which is why everything looked black. Setting the line width to 0 (lw=0) should do the trick.
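A minimal sketch of the fix described in the edit, using random sample data rather than the original lists: it draws the same tiny markers twice, once with the default outline and once with lw=0, and reads the outline widths back with get_linewidths(). Note that since matplotlib 2.0 the default scatter outline takes the marker's face colour, so on current versions the black-dot symptom is less likely to appear; lw=0 still removes the outline entirely. The matplotlib.use("Agg") line is an assumption that only lets the sketch run without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older matplotlib)
import numpy as np

# Hypothetical sample data: 200 random points in the unit cube.
rng = np.random.default_rng(0)
x, y, z = rng.random((3, 200))

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# Tiny markers (s=1) with the default outline: the outline can hide the fill colour.
outlined = ax.scatter(x, y, z, s=1, c=y, cmap="coolwarm")

# Same tiny markers with the outline switched off (lw is the short alias for linewidths).
no_outline = ax.scatter(x, y, z, s=1, c=y, cmap="coolwarm", lw=0)

print("default outline width(s):", outlined.get_linewidths())
print("lw=0 outline width(s):   ", no_outline.get_linewidths())
```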

I want to generate a 3D scatterplot of data points, colouring each point by the value of its y-coordinate, but I can't get the points to show any colour.

If a point's y-value is low, its colour should be closer to the blue end of the colour spectrum; if the value is higher, the colour should be closer to the red end.
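The value-to-colour mapping described above can be sketched on its own, outside of scatter: a Normalize scales the y-values onto [0, 1], and the colormap turns each scaled value into an RGBA colour. This is roughly what scatter does when it is given numeric c= values and a cmap= (by default it takes vmin and vmax from the data's minimum and maximum). The sample values and the choice of coolwarm here are assumptions for illustration.

```python
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

y_points = [0.0, 2.5, 5.0, 7.5, 10.0]  # hypothetical sample y-values

# Scale the data range [min, max] onto [0, 1], the input range a colormap expects.
norm = Normalize(vmin=min(y_points), vmax=max(y_points))
cmap = plt.get_cmap("coolwarm")  # blue at the low end, red at the high end

rgba = cmap(norm(y_points))  # one RGBA row per point
low, high = rgba[0], rgba[-1]
print("lowest y  ->", low)   # blue channel dominates
print("highest y ->", high)  # red channel dominates
```

One thing worth checking against the blue-low/red-high goal: the gist_rainbow colormap used in the question's code starts at red and only reaches blue near the top of its range before ending in magenta, so a colormap such as coolwarm matches the description more closely.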

I've managed to plot what I want in 2D, but am having trouble replicating the process in 3D. The current code only plots the points in black.

Here is my code for the 3D attempt, and a screenshot of the desired results in 2D. What exactly am I doing wrong here?

x_points, y_points, and z_points are lists of float values.

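import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib

def three_dimensional_scatterplot(
    self, x_points, y_points, z_points, data_file
):

    # Colormap used to turn each point's y-value into a colour.
    # (cm.get_cmap was deprecated in matplotlib 3.7 and has since been removed;
    # plt.get_cmap still works.)
    cm1 = plt.get_cmap('gist_rainbow')

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(
        x_points,
        y_points,
        z_points,
        s=1,          # very small markers
        c=y_points,   # colour each point by its y-value
        cmap=cm1
    )

    ax.set_xlabel('X axis')
    plt.show()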
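Putting the edit's fix back into the function above gives something like the following sketch. It is not the final code from the post: self and the unused data_file parameter are dropped so it runs as a plain function, coolwarm replaces gist_rainbow to match the blue-low/red-high description, and the colorbar and extra axis labels are additions. The matplotlib.use("Agg") line only lets it run without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; remove it to get an interactive window
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older matplotlib)


def three_dimensional_scatterplot(x_points, y_points, z_points):
    """Scatter the points in 3D, coloured by their y-value (blue = low, red = high)."""
    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")
    points = ax.scatter(
        x_points,
        y_points,
        z_points,
        s=1,              # very small markers, as in the question
        c=y_points,       # colour each point by its y-value
        cmap="coolwarm",  # assumption: swapped in for gist_rainbow
        lw=0,             # no outline, so the tiny markers show their fill colour
    )
    ax.set_xlabel("X axis")
    ax.set_ylabel("Y axis")
    ax.set_zlabel("Z axis")
    # Colour scale showing which y-value each colour stands for.
    fig.colorbar(points, ax=ax, label="y value")
    return fig, points


if __name__ == "__main__":
    # Hypothetical sample data standing in for x_points, y_points, z_points.
    xs = [0.1, 0.4, 0.7, 0.9]
    ys = [1.0, 2.0, 3.0, 4.0]
    zs = [0.5, 0.2, 0.8, 0.3]
    fig, points = three_dimensional_scatterplot(xs, ys, zs)
    plt.show()
```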

[Screenshot: the desired result, plotted in 2D]


Solution

  • You can plot it like this:

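    import matplotlib.pyplot as plt
    from matplotlib import cm
    from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older matplotlib
    import numpy as np
    
    # 25 random points in the unit cube
    x = np.random.rand(25)
    y = np.random.rand(25)
    z = np.random.rand(25)
    
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    # c=y colours each point by its y-value; coolwarm runs from blue (low) to red (high)
    p3d = ax.scatter(x, y, z, s=30, c=y, cmap=cm.coolwarm)
    plt.show()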
    
