Search code examples
python plotly plotly-python

Discrete colour scale for lines in Plotly graph objects (Python)


I'm mapping lines in plotly.graph_objects.Scattermapbox with fig.add_trace. Ideally, I would like to change the colour of the lines, such that it corresponds to an attribute in another column. For example:

Line Attribute
LINE01 Highway
LINE02 River
LINE03 Highway

I can't figure out how to do this. When I create my plot, each line appears as its own legend entry (LINE01 appears as trace 0, etc.). This is my code:

import plotly.graph_objects as go
fig = go.Figure()

# Draw every row of df as its own two-point line trace, running from the
# row's (lon.start, lat.start) to its (lon.end, lat.end).
for row in range(len(df)):
    segment_lon = [df['lon.start'][row], df['lon.end'][row]]
    segment_lat = [df['lat.start'][row], df['lat.end'][row]]
    segment = go.Scattermapbox(
        mode="lines",
        lon=segment_lon,
        lat=segment_lat,
        line={"width": 3},
    )
    fig.add_trace(segment)

How can I change it such that my legend is grouped by the Attribute column to create a discrete colour scale?


Solution

    • Your question did not include sample data for roads and rivers, so I have sourced UK river and road geometries from UK government sources.
    • there are two approaches to adding line layers to a plotly figure
      1. as mapbox layers. This carries the advantage that you can utilise plotly and geopandas geojson capabilities to more simply use reference mapping information from 3rd party sources
      2. as traces - the approach you have been using. You have used only the start and end of each line, which means you lose all the coordinates in between
    • I have used the answer to "How to plot visualize a Linestring over a map with Python?" as the basis for generating lines on a mapbox figure.
    • have used a subset of data, just because figure generation time is significant with full set of rivers and roads

    source sample rivers and roads data

    import urllib
    import urllib.parse
    from pathlib import Path
    from zipfile import ZipFile

    import geopandas as gpd
    import pandas as pd
    import requests
    
    # Download (once), unpack and load some river and road geometry.
    # Each entry gives the dataset name, the suffix of the file to load from
    # its zip archive, the colour and line width used when drawing it, and the
    # URL of the archive.
    src = [
        {
            "name": "rivers",
            "suffix": ".shp",
            "color": "blue",
            "width": 1.5,
            "url": "https://environment.data.gov.uk/UserDownloads/interactive/023ce3a412b84aca949cad6dcf6c5338191808/EA_StatutoryMainRiverMap_SHP_Full.zip",
        },
        {
            "name": "roads",
            "suffix": ".shp",
            "color": "red",
            "width": 3,
            "url": "https://maps.dft.gov.uk/major-road-network-shapefile/Major_Road_Network_2018_Open_Roads.zip",
        },
    ]
    data = {}
    for s in src:
        # the archive is cached in the working directory under the last
        # segment of the URL path
        archive = Path.cwd().joinpath(
            urllib.parse.urlparse(s["url"]).path.split("/")[-1]
        )
        if not archive.exists():
            # Stream into a temporary ".part" file and rename only on success,
            # so an interrupted download is never mistaken for a cached archive.
            partial = archive.with_name(archive.name + ".part")
            with requests.get(s["url"], stream=True, timeout=60) as r:
                # fail here rather than saving an HTTP error page as the zip
                r.raise_for_status()
                with open(partial, "wb") as fd:
                    for chunk in r.iter_content(chunk_size=1024 * 1024):
                        fd.write(chunk)
            partial.replace(archive)

        # unpack next to the archive, into a directory named after it
        extract_dir = archive.parent.joinpath(archive.stem)
        with ZipFile(archive) as fz:
            fz.extractall(extract_dir)
            members = [
                member.filename
                for member in fz.infolist()
                if Path(member.filename).suffix == s["suffix"]
            ]
        if not members:
            raise FileNotFoundError(
                f'no {s["suffix"]} file found in {archive.name}'
            )

        # load the first matching member and tag every row with its dataset name
        data[s["name"]] = gpd.read_file(extract_dir.joinpath(members[0])).assign(
            source_name=s["name"]
        )
    # single frame holding all datasets, reprojected to EPSG:4326
    gdf = pd.concat(data.values()).to_crs("EPSG:4326")
    

    use mapbox layers

    import plotly.graph_objects as go
    import json
    
    # Work with a smaller subset: rows whose length_km exceeds 50, plus rows
    # whose roadNumber appears in a random sample of 50 roadNumber values.
    is_long = gdf["length_km"].gt(50).fillna(False)
    sampled_road_numbers = gdf["roadNumber"].fillna("").sample(50).unique()
    is_sampled_road = gdf["roadNumber"].isin(sampled_road_numbers)
    gdf2 = gdf.loc[is_long | is_sampled_road]

    # Start from a figure holding a single empty Scattermapbox trace
    # (presumably so that a mapbox is rendered for the layers to sit on).
    fig = go.Figure(go.Scattermapbox())

    # Build one GeoJSON line layer per dataset, using geopandas' GeoJSON export
    # so the full definition of every line string is kept.
    line_layers = []
    for spec in src:
        subset = gdf2.loc[gdf2["source_name"].eq(spec["name"])]
        line_layers.append(
            {
                "source": json.loads(subset.geometry.to_json()),
                "below": "traces",
                "type": "line",
                "color": spec["color"],
                "line": {"width": spec["width"]},
            }
        )

    # centre the map on the midpoint of the full dataset's bounding box
    bounds = gdf.total_bounds
    map_center = {
        "lon": bounds[[0, 2]].mean(),
        "lat": bounds[[1, 3]].mean(),
    }

    fig.update_layout(
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        mapbox={
            "style": "carto-positron",
            "zoom": 4,
            "center": map_center,
            "layers": line_layers,
        },
    )
    

    enter image description here

    use mapbox lines

    import numpy as np
    
    # plotly takes one flat array per axis, with a gap value between lines.
    # Use numpy padding and reshaping to generate this array from the
    # coordinate columns of each feature.
    def line_array(df, cols):
        """Flatten coordinate columns into one gap-delimited 1-D array.

        Each row of ``df`` contributes the values of ``cols`` in order,
        followed by one NaN separator, e.g. ``[a0, b0, nan, a1, b1, nan]``
        for two columns.

        Args:
            df: DataFrame holding the coordinate columns.
            cols: List of column names to take from each row, in order.

        Returns:
            1-D float array of length ``len(df) * (len(cols) + 1)``.
        """
        values = df.loc[:, cols].to_numpy(dtype=float)
        # append one NaN column so every row ends with a separator
        padded = np.pad(values, [(0, 0), (0, 1)], constant_values=np.nan)
        # row-major flatten; works for any number of columns, not just two
        return padded.reshape(-1)
    
    
    # Map onto the question's column names. Each geometry is reduced to the
    # corners of its bounding box, so all detail of the line string is lost.
    bounds_as_endpoints = gdf2.geometry.bounds.rename(
        columns={
            "minx": "lon.start",
            "miny": "lat.start",
            "maxx": "lon.end",
            "maxy": "lat.end",
        }
    )
    gdf3 = gdf2.join(bounds_as_endpoints)

    # One named trace per source_name group, holding all of that group's lines
    # in a single gap-delimited array.
    traces = []
    for source_name, group in gdf3.groupby("source_name"):
        traces.append(
            go.Scattermapbox(
                name=source_name,
                lat=line_array(group, ["lat.start", "lat.end"]),
                lon=line_array(group, ["lon.start", "lon.end"]),
                mode="lines",
            )
        )
    fig = go.Figure(traces)

    # centre the map on the mean of the per-column means of the coordinates
    map_center = {
        "lon": gdf3.loc[:, ["lon.start", "lon.end"]].mean().mean(),
        "lat": gdf3.loc[:, ["lat.start", "lat.end"]].mean().mean(),
    }
    fig.update_layout(
        margin={"l": 0, "r": 0, "t": 15, "b": 0},
        mapbox={
            "style": "carto-positron",
            "zoom": 4,
            "center": map_center,
        },
    )
    

    enter image description here