
How to add a color-coded line with extra information


I want to plot suspension travel against distance for a train. My dataset contains this information, along with whether the train is on a straight track or in a curve. I want to mimic the following image, but I also have to add some extra information (bridge and tunnel locations, for example). The problem is that what I tried takes almost four minutes to run.

Plot I want to copy

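import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider

# df_irv is the DataFrame described above (columns: Distance, SuspTravel, Roll, Bounce, Element)

# Plot of the line I want to add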
def transition_line(xmin, xmax):
    for i in range((df_irv['Distance'] - xmin).abs().idxmin(), (df_irv['Distance'] - xmax).abs().idxmin()):
        plt.plot([df_irv['Distance'][i], df_irv['Distance'][i+1]], [max(df_irv['SuspTravel']) + 0.5]*2, color='red' if df_irv['Element'][i] == 'CURVA' else 'blue', linewidth=10, alpha=0.5)

# Function to plot the data with adjustable x-axis limits
def plot_graph(xmin, xmax, sensors):
    plt.figure(figsize=(10, 5))
    plt.plot(df_irv['Distance'], df_irv[sensors], label='Suspension Sensor')
    plt.xlim(xmin, xmax)
    plt.xlabel('Distance (Km)')
    plt.ylabel('Suspension Sensor')
    plt.title('Suspension Sensor vs Distance')
    plt.legend()
    plt.grid(True)
    transition_line(xmin, xmax)
    plt.show()

# Create sliders for x-axis limits
xmin_slider = IntSlider(value=0, min=0, max=df_irv['Distance'].max(), step=1, description='X min')
xmax_slider = IntSlider(value=20, min=0, max=df_irv['Distance'].max(), step=1, description='X max')

# Interactive plot
interact(plot_graph, xmin=xmin_slider, xmax=xmax_slider, sensors = ['SuspTravel', 'Roll', 'Bounce'])

Image produced by my attempt


Solution

  • Calling plt.plot() many times in a loop is slow, because every call creates a separate line object. (Also, max(df_irv['SuspTravel']) is recomputed on every iteration of the loop; it can be calculated once before the loop starts.)

    To speed up the drawing of the short lines, you can use an approach similar to the multicolored line example in matplotlib's documentation: build all segments at once with numpy and draw them as a single LineCollection. Vectorized numpy operations are much faster than a Python loop because the arrays are implemented in optimized C code.
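To make the segment-building step concrete, here is a minimal sketch on toy data. The arrays `x` and `y` are hypothetical values chosen only for illustration; they are not taken from the question's dataset.

```python
import numpy as np

# Hypothetical toy data: four x positions give three segments.
x = np.array([0.0, 1.0, 2.5, 4.0])
y = np.full(len(x) - 1, 10.0)  # constant height of the colored strip

# Each row built by np.c_ is (x_start, y, x_end, y). Reshaping gives an
# array of shape (n_segments, 2 points, 2 coordinates), which is the
# layout LineCollection expects.
segments = np.c_[x[:-1], y, x[1:], y].reshape(-1, 2, 2)

print(segments.shape)  # (3, 2, 2)
print(segments[0])     # first segment: from (0, 10) to (1, 10)
```

Note that n points always give n - 1 segments, so the y values and the colors need one entry fewer than the x values.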

    Here is what the code could look like:

    import matplotlib.pyplot as plt
    from matplotlib.collections import LineCollection
    import pandas as pd
    import numpy as np
    
    # create some dummy test data
    df_irv = pd.DataFrame({'Distance': np.random.randint(10, 100, 1000).cumsum(),
                           'SuspTravel': np.random.randn(1000).cumsum() * 100,
                           'Element': np.random.choice(['CURVA', 'RECTA'], 1000, p=[.1, .9])})
    
    fig, ax = plt.subplots()
    
    # add plot
    ax.plot('Distance', 'SuspTravel', data=df_irv)
    
    # add "transition line"
    xmin = df_irv['Distance'].min()
    xmax = df_irv['Distance'].max()
    
    # indices of the rows closest to xmin and xmax (assumes the default 0..n-1 index)
    id_xmin = (df_irv['Distance'] - xmin).abs().idxmin()
    id_xmax = (df_irv['Distance'] - xmax).abs().idxmin()
    xvals = df_irv['Distance'].to_numpy()[id_xmin:id_xmax + 1]
    # one y value and one color per segment (n points give n - 1 segments)
    yvals = np.full(id_xmax - id_xmin, df_irv['SuspTravel'].max() + 0.5)
    colors = df_irv['Element'][id_xmin:id_xmax].map({'CURVA': 'red', 'RECTA': 'blue'})
    
    # each segment is a pair of points: (x[i], y) -> (x[i+1], y)
    segments = np.c_[xvals[:-1], yvals, xvals[1:], yvals].reshape(-1, 2, 2)
    lines = LineCollection(segments, colors=colors)
    lines.set_linewidth(10)
    ax.add_collection(lines)
    plt.show()
    

    showing a multi-colored line on top
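To use this with the sliders from the question, the same idea can be wrapped in a function that takes the axes and the x-range. The following is only a sketch under some assumptions: the helper name `add_transition_line` and its `offset` parameter are made up for illustration, the 'Distance' column is assumed to be sorted in ascending order, and `np.searchsorted` is used to find the rows inside the range instead of the nearest-index lookup above. The demo data is random and hypothetical.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection


def add_transition_line(ax, df, xmin, xmax, offset=0.5):
    """Draw a red/blue strip above the data for rows with xmin <= Distance <= xmax.

    Hypothetical helper; assumes df['Distance'] is sorted ascending.
    """
    dist = df['Distance'].to_numpy()
    lo = np.searchsorted(dist, xmin, side='left')   # first row with Distance >= xmin
    hi = np.searchsorted(dist, xmax, side='right')  # one past last row with Distance <= xmax
    x = dist[lo:hi]
    if len(x) < 2:
        return None  # fewer than two points: no segment to draw

    # one y value and one color per segment
    y = np.full(len(x) - 1, df['SuspTravel'].max() + offset)
    element = df['Element'].to_numpy()[lo:hi - 1]
    colors = np.where(element == 'CURVA', 'red', 'blue').tolist()

    segments = np.c_[x[:-1], y, x[1:], y].reshape(-1, 2, 2)
    strip = LineCollection(segments, colors=colors, linewidths=10, alpha=0.5)
    return ax.add_collection(strip)


# Hypothetical demo data, similar to the dummy data above
rng = np.random.default_rng(0)
df_demo = pd.DataFrame({
    'Distance': rng.integers(10, 100, 1000).cumsum(),
    'SuspTravel': rng.standard_normal(1000).cumsum(),
    'Element': rng.choice(['CURVA', 'RECTA'], 1000, p=[0.1, 0.9]),
})

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(df_demo['Distance'], df_demo['SuspTravel'])
strip = add_transition_line(ax, df_demo, xmin=5000, xmax=20000)
ax.set_xlim(5000, 20000)
```

In the question's plot_graph, the call transition_line(xmin, xmax) could then be replaced with add_transition_line(plt.gca(), df_irv, xmin, xmax). Extra categories such as bridges or tunnels could probably be handled by extending the color lookup, provided the data has a column that marks them.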