Tags: python, regression, data-visualization, scatter-plot, altair

Altair: Regression over a scatter plot coloured with a continuous scale


I'm trying to use Altair to show a scatter plot where the mark colour is given by a continuous (non-categorical) feature, and to add a regression line over it. I'm quite new to Altair, though.

First, here's a sample of the data I'm inputting into Altair.

So far I can get either the colouring without a regression line:

The problem is that the regression line does not appear. This chart was produced by the following code:

chart = alt.Chart(df).mark_circle(size=60).encode(
   x='Lambda',
   y='ACR',
   color='Consensus',
   tooltip=["Tolerance Prevalence", "Consensus", "ACR", "Lambda"]
)

chart.interactive() + chart.transform_regression(
   'Lambda', 'ACR', method="quad"
).mark_line(color="red")

Or I can get a regression line without the desired colour scale, simply by removing the color='Consensus' line from the first snippet.

I have tried changing regression methods, and even different feature combinations, to no avail. Is there an argument or Altair function I can use to fix this easily?

Thanks in advance!

EDIT 1

Full code:

import altair as alt
import pandas as pd
import numpy as np

# processed.txt is space-separated; read_csv must be called via pandas
df = pd.read_csv("processed.txt", sep=' ')

chart = alt.Chart(df).mark_circle(size=60).encode(
    x='Lambda',
    y='ACR',
    color='Consensus',
    tooltip=["Tolerance Prevalence", "Consensus", "ACR", "Lambda"]
)

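# Layer the regression over the scatter and keep the result: ending the
# cell with `chart` alone would display only the scatter layer
layered = chart.interactive() + chart.transform_regression(
    'Lambda', 'ACR', method="quad").mark_line(color="red")

layered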

Full Data (processed.txt)
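A note on why the coloured version loses the line: the regression layer draws from the transform's output, which in Vega-Lite holds only the `on` and regression fields (here Lambda and ACR, plus any groupby fields). The rough sketch below mimics that output outside Altair. Assumptions: NumPy's least-squares `np.polyfit` stands in for the quadratic fit Vega-Lite computes in the browser, the 50-point sampling of the curve is arbitrary, `quad_regression_rows` is a hypothetical helper, and the toy data is made up rather than taken from processed.txt.

```python
import numpy as np
import pandas as pd


def quad_regression_rows(df, on, regression, n_points=50):
    """Hypothetical helper approximating what a 'quad' regression layer draws:
    a sampled curve of (on, regression) pairs fitted by least squares."""
    # Least-squares fit of regression ~ a*on**2 + b*on + c
    coeffs = np.polyfit(df[on], df[regression], deg=2)
    # Sample the fitted curve across the observed extent of `on`
    xs = np.linspace(df[on].min(), df[on].max(), n_points)
    # Only the two regression columns survive; every other column is dropped
    return pd.DataFrame({on: xs, regression: np.polyval(coeffs, xs)})


# Made-up toy data with the same column names (not the asker's data)
rng = np.random.default_rng(0)
lam = np.linspace(0.0, 1.0, 30)
toy = pd.DataFrame({
    "Lambda": lam,
    "ACR": 2 * lam**2 - lam + 0.1 * rng.normal(size=lam.size),
    "Consensus": rng.uniform(size=lam.size),
})

fitted = quad_regression_rows(toy, "Lambda", "ACR")
print(list(fitted.columns))  # ['Lambda', 'ACR']; no 'Consensus' column
```

Because Consensus is absent from those rows, a color='Consensus' encoding inherited by the line layer has nothing to map.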


Solution

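  • You could define a base chart without the coloring and then build both the scatter and the regression line from it, adding the color encoding only to the scatter. That way the regression layer does not inherit color='Consensus', a column that the regression output no longer contains: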

    import altair as alt
    import pandas as pd
    import numpy as np
    
    
    df = pd.read_csv("Downloads/processed.txt", sep=' ')
    
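    # Shared positional encodings only; no colour, so nothing is inherited
    base = alt.Chart(df).mark_circle(size=60).encode(
        x='Lambda',
        y='ACR',
    )
    
    # Scatter layer: add the continuous colour scale and tooltips here
    scatter = base.encode(
        color='Consensus',
        tooltip=["Tolerance Prevalence", "Consensus", "ACR", "Lambda"]
    )
    
    # Regression layer: its output holds only Lambda and ACR,
    # so it must not carry the Consensus colour encoding
    line = base.mark_line(color="red").transform_regression(
        'Lambda', 'ACR', method="quad")
    
    # Draw the fitted line on top of the points
    scatter + line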
    
