Tags: python, streamlit, altair

Altair plot in streamlit: How to add a legend?


I'm using Streamlit and need Altair for plotting (because of its interpolation options).

Given this simple code:

import streamlit as st
import altair as alt
import pandas as pd

data = pd.DataFrame({"x": [0, 1, 2, 3], "y": [0, 10, 15, 20], "z": [10, 8, 10, 1]})

base = alt.Chart(data).encode(x="x")

chart = alt.layer(
    base.mark_line().encode(y="y", color=alt.value("green")),
    base.mark_line().encode(y="z", color=alt.value("red")),
).properties(title="My plot")

st.altair_chart(chart, theme="streamlit", use_container_width=True)

Which results in this plot: [image: the y line in green and the z line in red, with no legend]

What's the correct way to add a legend next to the plot?

In the documentation, I see the legend option as part of the color encoding, but it always seems to be tied to visualizing another data dimension. In my case, I just want to plot several lines and show a legend with their respective colors.


Solution

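  • Transform your data into long format (one row per x value and series). Altair builds a legend from a color encoding that is driven by a data field, and in long format each line is identified by a value in a "category" column. Encoding color by that column therefore gives you one line per series plus a matching legend: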

    import streamlit as st
    import altair as alt
    import pandas as pd
    
    # Your data
    data = pd.DataFrame({
        "x": [0, 1, 2, 3],
        "y": [0, 10, 15, 20],
        "z": [10, 8, 10, 1]
    })
    
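    # Transform data to long format: one row per (x, series) pair,
    # with the original column name in "category" and the number in "value"
    data_long = pd.melt(data, id_vars=['x'], value_vars=['y', 'z'], var_name='category', value_name='value')
    
    # Create an Altair chart; encoding color by "category" produces the legend
    chart = alt.Chart(data_long).mark_line().encode(
        x='x:Q',
        y='value:Q',
        color='category:N'  # Use the category field for color encoding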
    ).properties(
        title="My plot"
    )
    
    
    st.altair_chart(chart, use_container_width=True)
    

    Output: [image: one line per category, with a legend listing y and z]