
Altair changes order of legend when saving data to disk vs. storing data in the plot


Consider the following dataframe, created with the code snippet at the end of this question:

    A           B           C
0   1.624345    -0.611756   2
1   -1.072969   0.865408    1
2   1.744812    -0.761207   2
3   -0.249370   1.462108    2
4   -0.322417   -0.384054   11
5   -1.099891   -0.172428   1
6   0.042214    0.582815    1
7   1.144724    0.901591    11
8   0.900856    -0.683728   11
9   -0.935769   -0.267888   2

I want to create a scatter plot using the A and B columns and color the plot by C. Running

alt.Chart(randf).mark_point(size=100, filled=True).encode(
    x="A", y="B", color="C:N"
).properties(width=200, height=200)

I get a chart whose legend is sorted 1, 2, 11.

However, my data is often large, so instead of embedding it in the plot I store it on disk. I then usually run

import os
from toolz.curried import pipe

def csv_dir(data, data_dir="altairdata"):
    # Write the data to a CSV file in data_dir and return a reference to that file.
    os.makedirs(data_dir, exist_ok=True)
    return pipe(data, alt.to_csv(filename=data_dir + "/{prefix}-{hash}.{extension}"))

alt.data_transformers.register("csv_dir", csv_dir)
alt.data_transformers.enable("csv_dir", data_dir="./.temporary_altair_data/")

When I rerun the same plot, the legend order is now 1, 11, 2. More generally, the values of C seem to be treated as strings and sorted lexicographically (at least that is my observation).
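The 1, 11, 2 order is what you get when the category values are compared as text rather than as numbers. A minimal sketch of the difference in plain Python (it only illustrates the two sort orders and does not involve Altair; the variable names are made up for the illustration):

```python
# The distinct values of C as numbers, and as the text a CSV file would hold.
values = [2, 1, 11]
as_text = [str(v) for v in values]

numeric_order = sorted(values)   # compares magnitudes
text_order = sorted(as_text)     # compares character by character

print(numeric_order)  # [1, 2, 11]
print(text_order)     # ['1', '11', '2']
```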

Is there a way to have the correct order of the legend while still saving data to disk?

Full code to reproduce:

import pandas as pd
import altair as alt
import numpy as np 
import os
from toolz.curried import pipe


np.random.seed(1)
randf = pd.DataFrame(np.random.randn(10, 3), columns=list("ABC"))
randf["C"] = np.random.choice([1, 2, 11], 10)

chart1 = alt.Chart(randf).mark_point(size=100, filled=True).encode(
    x="A", y="B", color="C:N"
)
chart1 

followed by


def csv_dir(data, data_dir="altairdata"):
    # Write the data to a CSV file in data_dir and return a reference to that file.
    os.makedirs(data_dir, exist_ok=True)
    return pipe(data, alt.to_csv(filename=data_dir + "/{prefix}-{hash}.{extension}"))

alt.data_transformers.register("csv_dir", csv_dir)
alt.data_transformers.enable("csv_dir", data_dir="./.temporary_altair_data/")

chart2 = alt.Chart(randf).mark_point(size=100, filled=True).encode(
    x="A", y="B", color="C:N"
)
chart2

Solution

  • This looks like a bug, but I found a workaround. The trick is to add a transform_filter with a oneOf predicate that lists all possible values of C as sorted numbers.

    def csv_dir(data, data_dir="altairdata"):
        # Write the data to a CSV file in data_dir and return a reference to that file.
        os.makedirs(data_dir, exist_ok=True)
        return pipe(data, alt.to_csv(filename=data_dir + "/{prefix}-{hash}.{extension}"))

    alt.data_transformers.register("csv_dir", csv_dir)
    alt.data_transformers.enable("csv_dir", data_dir="./.temporary_altair_data/")

    chart2 = alt.Chart(randf).mark_point(size=100, filled=True).encode(
        x=alt.X("A:Q"),
        y=alt.Y("B:Q"),
        color=alt.Color("C:N"),
    ).transform_filter(
        # Keep only rows whose C is one of the known values, given as sorted numbers.
        {"field": "C", "oneOf": np.sort(randf.C.unique()).tolist()}
    )
    chart2
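The predicate passed to transform_filter is an ordinary dictionary, so it can be built and inspected on its own. A minimal sketch, assuming a hypothetical helper named one_of_predicate (not part of Altair) that converts the values to plain Python ints, removes duplicates, and sorts them:

```python
def one_of_predicate(field, values):
    """Build a 'oneOf' field predicate dict (hypothetical helper, not an Altair API)."""
    # int() also accepts NumPy integers; the set removes duplicates before sorting.
    unique_sorted = sorted({int(v) for v in values})
    return {"field": field, "oneOf": unique_sorted}

# The values of column C from the example dataframe.
predicate = one_of_predicate("C", [2, 1, 2, 2, 11, 1, 1, 11, 11, 2])
print(predicate)  # {'field': 'C', 'oneOf': [1, 2, 11]}
```

The resulting dictionary could then be passed as `.transform_filter(predicate)`. The workaround probably works because the numeric values in the predicate prompt Vega-Lite to parse C as a number when it loads the CSV file; that explanation is an assumption about Vega-Lite's type inference and is not confirmed by the answer itself.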
    
