
Plot with colors depending on data


I would like to plot some data with colors that depend on certain conditions. Ideally, I would like to do it in both plotly and matplotlib (as separate scripts).

The data

For example, I have the following data:

import pandas as pd

data = {
    'X':  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'Y':  [5, 4, 3, 2, 1, 2, 3, 4, 5, 5],
    'XL': [2, None, 4, None, None, None, 4, 5, None, 3],
    'YL': [3, None, 2, None, None, None, 5, 6, None, 4],
    'XR': [None, 4, None, 1, None, None, None, 4, 5, 4],
    'YR': [None, 3, None, 5, None, None, None, 3, 4, 4]
}

df = pd.DataFrame(data)

The simple plots

With matplotlib:

import matplotlib.pyplot as plt
fig, ax = plt.subplots()

# Plot X, Y
ax.plot(df['X'], df['Y'], linestyle='-', marker='o')

# Update plot settings
ax.set_title('Trajectory Plot')
ax.set_xlabel('X-axis')
ax.set_ylabel('Y-axis')

# Show the plot
plt.show()

And with plotly:

import plotly.graph_objects as go

# Create a scatter plot
fig = go.Figure(data=go.Scatter(x=df['X'], y=df['Y'], mode='lines+markers'))

# Update layout for better visibility
fig.update_layout(
    title='Trajectory Plot',
    xaxis_title='X-axis',
    yaxis_title='Y-axis',
)

# Show the plot
fig.show()

The problem

I would like to modify the scripts so that the color depends on whether the (XL, YL) and (XR, YR) pairs exist:

  • Grey: none exist
  • Red: Only XL,YL exists
  • Blue: Only XR,YR exists
  • Green: Both exists
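The color rule above maps directly onto numpy's np.select. This is a minimal sketch that assigns each row its color, assuming a pair "exists" only when both of its coordinates are non-null (the question implies this but does not state it); the column name color is a hypothetical choice.

```python
import numpy as np
import pandas as pd

# The question's data
df = pd.DataFrame({
    'X':  [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    'Y':  [5, 4, 3, 2, 1, 2, 3, 4, 5, 5],
    'XL': [2, None, 4, None, None, None, 4, 5, None, 3],
    'YL': [3, None, 2, None, None, None, 5, 6, None, 4],
    'XR': [None, 4, None, 1, None, None, None, 4, 5, 4],
    'YR': [None, 3, None, 5, None, None, None, 3, 4, 4],
})

# Assumption: a pair counts as existing only if both coordinates are present
has_left = df[['XL', 'YL']].notna().all(axis=1)
has_right = df[['XR', 'YR']].notna().all(axis=1)

# np.select takes the first condition that matches; unmatched rows get the default
df['color'] = np.select(
    [has_left & has_right, has_left, has_right],
    ['green', 'red', 'blue'],
    default='grey',
)
print(df['color'].tolist())
# ['red', 'blue', 'red', 'blue', 'grey', 'grey', 'red', 'green', 'blue', 'green']
```

Because the "both" condition is listed first, the has_left and has_right branches only ever see rows where exactly one pair exists.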

In the end, it should look like this (pardon the crude picture; I painted over the original blue lines):

How can I add this in matplotlib and plotly?

[Image: sketch of the desired plot]


Solution

  • IIUC, you can use matplotlib's LineCollection (see example here), which draws every segment between two consecutive points with its own color:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    
    # One color per point; green (both pairs exist) is the default
    arr = np.array(['green'] * len(df), dtype=str)
    arr[df['XL'].isna() & df['XR'].isna()] = 'grey'   # neither pair exists
    arr[df['XL'].notna() & df['XR'].isna()] = 'red'   # only (XL, YL)
    arr[df['XL'].isna() & df['XR'].notna()] = 'blue'  # only (XR, YR)
    
    # Build line segments with shape (n_points - 1, 2, 2): [[x0, y0], [x1, y1]] each
    points = np.array([df['X'], df['Y']]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)
    
    fig, ax = plt.subplots()
    # Segment i (point i -> point i + 1) takes the color of its start point
    lc = LineCollection(segments, colors=arr[:-1])
    lc.set_linewidth(2)
    ax.add_collection(lc)
    
    # Add a scatter for the point markers, drawn above the lines
    ax.scatter(df['X'], df['Y'], c=arr, zorder=3)
    
    # add_collection does not autoscale the axes, so set the limits explicitly
    ax.set_xlim(df['X'].min() - 1, df['X'].max() + 1)
    ax.set_ylim(df['Y'].min() - .1, df['Y'].max() + .1)
    plt.show()
    

    Output:

    [Image: matplotlib output]

    For plotly, one option is to add each line segment as a separate trace:

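    import itertools as it
    import plotly.graph_objects as go
    
    # Create coordinate pairs: (p0, p1), (p1, p2), ... (itertools.pairwise needs Python 3.10+)
    x_pairs = it.pairwise(df['X'])
    y_pairs = it.pairwise(df['Y'])
    
    # Create the base figure
    fig = go.Figure()
    
    # Add one trace per line segment, colored by its start point.
    # `arr` is the per-point color array from the matplotlib snippet above;
    # zip stops after the last segment, so the final color is not used here.
    for x, y, color in zip(x_pairs, y_pairs, arr):
        fig.add_trace(
            go.Scatter(
                x=list(x),
                y=list(y),
                mode='lines+markers',
                line={'color': color},
                marker={'color': color},
            )
        )
    
    fig.update_layout(showlegend=False)
    fig.show()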
    

    Output:

    [Image: plotly output]
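The plotly approach creates one trace per segment, so n points give n - 1 traces. If that becomes unwieldy, consecutive segments of the same color can share a trace. Below is a sketch of that idea using itertools.groupby; color_runs and point_color are hypothetical helpers (not part of plotly), and like the answers above it colors each segment by its start point and decides existence from XL/XR only. It is plain Python, so it runs without plotly installed.

```python
from itertools import groupby


def point_color(left, right):
    """Hypothetical helper: the question's rule, with None marking a missing pair."""
    if left is not None and right is not None:
        return 'green'
    if left is not None:
        return 'red'
    if right is not None:
        return 'blue'
    return 'grey'


def color_runs(xs, ys, colors):
    """Hypothetical helper: merge consecutive segments that share a color.

    Segment i joins point i to point i + 1 and takes colors[i] (its start point).
    Returns (x_list, y_list, color) tuples; each run's last point is repeated as
    the next run's first point, so the runs stay connected.
    """
    xs, ys = list(xs), list(ys)
    runs = []
    start = 0
    for color, group in groupby(range(len(xs) - 1), key=lambda i: colors[i]):
        end = start + len(list(group))  # index of the run's last point
        runs.append((xs[start:end + 1], ys[start:end + 1], color))
        start = end
    return runs


# The question's data as plain lists
X = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
Y = [5, 4, 3, 2, 1, 2, 3, 4, 5, 5]
XL = [2, None, 4, None, None, None, 4, 5, None, 3]
XR = [None, 4, None, 1, None, None, None, 4, 5, 4]

colors = [point_color(left, right) for left, right in zip(XL, XR)]
runs = color_runs(X, Y, colors)
for run_x, run_y, color in runs:
    print(color, run_x, run_y)
```

Each run would then become a single go.Scatter(x=run_x, y=run_y, mode='lines+markers', line={'color': color}) trace. With this particular data only the two grey segments merge (9 traces become 8), so the saving mostly matters for longer trajectories with long single-color stretches.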