Search code examples
python plotly

Plotly Express line with continuous color scale


I have the following piece of code

import plotly.express as px
import pandas as pd
import numpy as np

x = [1, 2, 3, 4, 5, 6]

# Three curves sampled at the same x values -- y = x, y = x**2 and
# y = x**0.5 -- stacked in long form and tagged 0, 1 and 2 in 'color'.
df = pd.DataFrame(
    {
        'x': x * 3,
        'y': [*np.array(x), *np.array(x)**2, *np.array(x)**.5],
        'color': [curve_id for curve_id in (0, 1, 2) for _ in x],
    }
)

# Draw the same data once with each plotly.express function; the figure
# title records which function produced it.
for plot_fn in (px.scatter, px.line):
    plot_title = f'Using {plot_fn.__name__}'
    fig = plot_fn(df, x='x', y='y', color='color', title=plot_title)
    fig.show()

which produces the following two plots:

enter image description here

enter image description here

For some reason px.line is not producing the continuous color scale that I want, and in the documentation for px.scatter I cannot find how to join the points with lines. How can I produce a plot with a continuous color scale and lines joining the points for each trace?

This is the plot I want to produce: enter image description here


Solution

  • I am not sure this is possible using only plotly.express. If you use px.line, then you can pass the argument markers=True as described in this answer, but from the px.line documentation it doesn't look like continuous color scales are supported.

    UPDATED ANSWER: to get a legend in which each entry controls a line and its markers together, it's probably simplest to use go.Scatter with the argument mode='lines+markers'. You'll need to add the traces one at a time (plotting the subset of the data for each unique color value separately) so that each line+marker group can be toggled from the legend.

    When plotting these traces, you will need some functions to retrieve the colors of the lines from the continuous color scale because go.Scatter won't know what color your lines are supposed to be unless you specify them - thankfully that has been answered here.

    Also, you won't get a colorbar when the traces are added one color at a time. To add one, plot all of the points again in a single go.Scatter trace with the argument marker=dict(size=0, color="rgba(0,0,0,0)", colorscale='Plasma', cmin=min(df.color), cmax=max(df.color), colorbar=dict(thickness=40)): this displays the colorbar while keeping the duplicate markers invisible.

    Putting all of this together:

    # import plotly.express as px
    import plotly.graph_objects as go
    import pandas as pd
    import numpy as np
    
    x = [1, 2, 3, 4, 5, 6]

    # Three curves sampled at the same x values -- y = x, y = x**2 and
    # y = x**0.5 -- stacked in long form and tagged 0, 1 and 2 in 'color'.
    df = pd.DataFrame(
        {
            'x': x * 3,
            'y': [*np.array(x), *np.array(x)**2, *np.array(x)**.5],
            'color': [curve_id for curve_id in (0, 1, 2) for _ in x],
        }
    )
    
    # This function allows you to retrieve colors from a continuous color scale
    # by providing the name of the color scale, and the normalized location between 0 and 1
    # Reference: https://stackoverflow.com/questions/62710057/access-color-from-plotly-color-scale
    
    def get_color(colorscale_name, loc):
        """
        Look up one or more colors on a named Plotly continuous colorscale.

        :param colorscale_name: name of the colorscale (for example 'Plasma');
            it is resolved through plotly's ColorscaleValidator.
        :param loc: a normalized position between 0 and 1, or an iterable of
            such positions.
        :return: a single color when ``loc`` is a scalar, otherwise a list with
            one color per position, each computed by get_continuous_color.
        """
        from _plotly_utils.basevalidators import ColorscaleValidator

        # First argument: name of the property being validated.
        # Second argument: a string that doesn't really matter for this use.
        validator = ColorscaleValidator("colorscale", "")
        # The resolved scale is a list of lists: [[loc1, "rgb1"], [loc2, "rgb2"], ...]
        resolved_scale = validator.validate_coerce(colorscale_name)

        # Anything without __iter__ is treated as a single position.
        if not hasattr(loc, "__iter__"):
            return get_continuous_color(resolved_scale, loc)
        return [get_continuous_color(resolved_scale, position) for position in loc]
            
    
    # Identical to Adam's answer
    import plotly.colors
    from PIL import ImageColor
    
    def get_continuous_color(colorscale, intermed):
        """
        Return the color found at position ``intermed`` on a Plotly colorscale.

        Plotly continuous colorscales assign colors to the range [0, 1]; this
        function picks the two colorscale entries that bracket ``intermed`` and
        blends between them.

        Plotly does not expose every colorscale in one common format.
        Some can be used directly:

            colorscale = plotly.colors.PLOTLY_SCALES["Greens"]

        Others are plain swatches that first have to be turned into a colorscale:

            viridis_colors, scale = plotly.colors.convert_colors_to_same_type(plotly.colors.sequential.Viridis)
            colorscale = plotly.colors.make_colorscale(viridis_colors, scale=scale)

        :param colorscale: a plotly continuous colorscale, i.e. a sequence of
            ``[cutoff, color]`` pairs; colors may be RGB strings or "#..." hex strings.
        :param intermed: value in the range [0, 1]; values outside it are
            clamped to the first / last color.
        :return: color in rgb string format
        :rtype: str
        :raises ValueError: if ``colorscale`` is empty.
        """
        if len(colorscale) == 0:
            raise ValueError("colorscale must have at least one color")

        def as_rgb_string(raw_color):
            # Prefix the stringified PIL color tuple with "rgb".
            return "rgb" + str(ImageColor.getcolor(raw_color, "RGB"))

        # Positions at or beyond either end (or a one-entry scale) need no
        # interpolation: return the edge color, converting it only if it is hex.
        use_first = intermed <= 0 or len(colorscale) == 1
        if use_first or intermed >= 1:
            edge_color = colorscale[0][1] if use_first else colorscale[-1][1]
            return as_rgb_string(edge_color) if edge_color[0] == "#" else edge_color

        # Walk the entries in order, remembering the last one below ``intermed``
        # and stopping at the first one that is not below it.
        for stop, shade in colorscale:
            if intermed > stop:
                lower_stop, lower_shade = stop, shade
                continue
            upper_stop, upper_shade = stop, shade
            break

        # Some colorscale names (such as cividis) come back as
        # [[loc1, "hex1"], [loc2, "hex2"], ...]; if either bracket is hex,
        # both are passed through the converter.
        if "#" in (lower_shade[0], upper_shade[0]):
            lower_shade = as_rgb_string(lower_shade)
            upper_shade = as_rgb_string(upper_shade)

        # Re-express ``intermed`` as a fraction of the distance between the
        # two bracketing cutoffs.
        span_fraction = (intermed - lower_stop) / (upper_stop - lower_stop)
        return plotly.colors.find_intermediate_color(
            lowcolor=lower_shade,
            highcolor=upper_shade,
            intermed=span_fraction,
            colortype="rgb",
        )
    
    fig = go.Figure()

    # Extremes of the color column: used to map each color value onto [0, 1]
    # and to set the range of the colorbar.
    color_min = min(df.color)
    color_max = max(df.color)

    ## one lines+markers trace per distinct color value, so that each legend
    ## entry refers to a single line together with its markers
    for color_val in df.color.unique():
        color_val_normalized = (color_val - color_min) / (color_max - color_min)
        df_subset = df[df['color'] == color_val]
        trace_label = f"line+marker {color_val}"
        curve_trace = go.Scatter(
            x=df_subset['x'],
            y=df_subset['y'],
            mode='lines+markers',
            marker={"color": get_color('Plasma', color_val_normalized)},
            name=trace_label,
            legendgroup=trace_label,
        )
        fig.add_trace(curve_trace)

    ## a second copy of every point, drawn with zero-size transparent markers:
    ## it exists only to carry the colorbar and is left out of the legend
    colorbar_marker = {
        "size": 0,
        "color": "rgba(0,0,0,0)",
        "colorscale": 'Plasma',
        "cmin": color_min,
        "cmax": color_max,
        "colorbar": {"thickness": 40},
    }
    fig.add_trace(go.Scatter(
        x=df['x'],
        y=df['y'],
        mode='markers',
        marker=colorbar_marker,
        showlegend=False,
    ))

    # Anchor the legend at the top-left of the plot and pad the y axis by
    # 2 units on each side of the data.
    legend_position = {"yanchor": "top", "y": 0.99, "xanchor": "left", "x": 0.01}
    fig.update_layout(
        legend=legend_position,
        yaxis_range=[min(df.y) - 2, max(df.y) + 2],
    )

    fig.show()
    

    enter image description here