Calculate new series based on data in filtered altair chart


I have a timeseries dataframe which I'm using to create a heatmap (showing the difference on the most recent observation between the two relevant values) and two line charts -- one showing the series themselves and another showing the difference between the series. I'm able to make the first two charts, but I can't figure out how to calculate the spread between the two selected series (using a .selection_point()) on the fly. Code snippets below.

import pandas as pd
import numpy as np
import altair as alt

ex_ts = pd.DataFrame(
    np.random.random((10, 5)),
    columns=['a', 'b', 'c', 'd', 'e'],
    index=(
        pd.date_range(
            start=pd.to_datetime('today')-pd.Timedelta(9, unit='D'), 
            end=pd.to_datetime('today')).strftime('%Y-%m-%d')
        )
)

ex_ts_long = ex_ts.stack().reset_index().set_axis(
    ['date', 'category', 'diff'],
    axis=1
).assign(
    x = lambda a: a['category'],
    y = lambda a: a['category']
)

print(ex_ts_long.head())

#          date category      diff  x  y
# 0  2024-03-03        a  0.910670  a  a
# 1  2024-03-03        b  0.608069  b  b
# 2  2024-03-03        c  0.797001  c  c
# 3  2024-03-03        d  0.139386  d  d
# 4  2024-03-03        e  0.147499  e  e

def get_last_diff(i):
    return ex_ts.sub(ex_ts.iloc[:,i], axis=0).iloc[-1,:]

ex_z = pd.concat(
    [get_last_diff(i) for i in np.arange(0, 5)],
    axis=1
).set_axis(
    ex_ts.columns,
    axis=1).stack().reset_index().set_axis(
    ['x', 'y', 'diff'],
    axis=1
).round(2)

print(ex_z.head())

#    x  y  diff
# 0  a  a  0.00
# 1  a  b -0.29
# 2  a  c -0.16
# 3  a  d -0.27
# 4  a  e -0.38

select_x = alt.selection_point(fields=['x'], name='select_x')
select_y = alt.selection_point(fields=['y'], name='select_y')

base = alt.Chart(ex_z).encode(
    x='x',
    y='y',
    color='diff'
).add_params(
    select_x
).add_params(
    select_y
).properties(
    width=500,
    height=500
)

hmap = base.mark_rect()
text = base.mark_text(fontWeight='bold').encode(
    text='diff',
    color=alt.value('red')
)
hmap_chart = (hmap + text)

line_1 = alt.Chart(ex_ts_long).mark_line().encode(
    x='date',
    y='diff',
    color='category'
).transform_filter(select_x | select_y)

tmp = alt.vconcat(hmap_chart, line_1)

The above code works to create a heatmap which you can click to filter the chart on the bottom. The problem, however, is that I want to calculate the difference between the two series in the first line chart, and plot it.

The most promising attempt was to create a new chart by layering two filtered charts. I aggregated within each of the two filtered charts so that I could reference the new variables to create the variable I'm looking for, but that didn't work. More example code below.

# Series picked via the y selection, aggregated per date
rhs_line1 = alt.Chart(ex_ts_long).mark_line().transform_filter(
    select_y
).transform_aggregate(
    agg_y = 'sum(diff)',
    groupby=['date']
).encode(x='date:T', y='agg_y:Q')

# Series picked via the x selection, aggregated per date
rhs_line2 = alt.Chart(ex_ts_long).mark_line().transform_filter(
    select_x
).transform_aggregate(
    agg_x = 'sum(diff)',
    groupby=['date']
).encode(
    x='date:T',
    y='agg_x:Q'
)

rhs_line = (rhs_line1 + rhs_line2).transform_calculate(
    spread = 'datum.agg_y - datum.agg_x'
).encode(
    x='date:T',
    y='spread:Q'
)

final = alt.vconcat(hmap_chart, alt.hconcat(line_1, rhs_line))

Solution

  • It can be quite tricky to troubleshoot what is happening with the data when transforms are involved. It can be really helpful to either use the .transformed_data() method on the chart in Altair, or click the three-dot menu in the top right, choose "open chart in vega editor", and watch in the "Data viewer" tab how the data changes as you make selections.

    Doing this, I could see that spread is actually undefined in your case, since the layered charts' data is not merged; instead there are two separate datasets, one that contains agg_y and one that contains agg_x, so it is not possible to compute agg_y - agg_x in the calculate transform. I'm unsure whether there is a way to merge these, but either way, I don't think you want to go about it this way, given how your data is structured.

    Since you want to subtract two columns from each other, I think it is simpler to pivot the dataframe and then dynamically identify the columns to subtract based on the selection, which you can do as follows:

    dynamic_title = alt.Title(alt.expr(f'"Difference between " + {select_x.name}.x + " and " + {select_y.name}.y'))
    rhs_line = alt.Chart(ex_ts_long, title=dynamic_title).transform_pivot(
        'category', 'diff', groupby=['date']
    ).transform_calculate(
        spread = f'datum[{select_x.name}.x] - datum[{select_y.name}.y]'
    ).mark_line(color='grey').encode(
        x='date',
        y=alt.Y('spread:Q').scale(domain=(-1, 1)),
    )
    
    alt.vconcat(
        hmap_chart,
        alt.hconcat(line_1, rhs_line).resolve_scale(color='independent')
    )
    

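To sanity-check what the pivot and calculate transforms compute, the same reshaping can be reproduced in pandas for one fixed pair of categories. This is only a sketch: the small frame `long_df`, the fixed pair ('a', 'b'), and the helper name `spread_for` are illustrative assumptions, standing in for `ex_ts_long` and for the values that the two selections would supply in the chart.

```python
import pandas as pd

# Small long-format frame with the same columns as ex_ts_long (values are made up).
long_df = pd.DataFrame({
    'date': ['2024-03-03', '2024-03-03', '2024-03-04', '2024-03-04'],
    'category': ['a', 'b', 'a', 'b'],
    'diff': [0.9, 0.6, 0.5, 0.7],
})

def spread_for(df, x, y):
    """Pivot to one column per category, then subtract column y from column x."""
    wide = df.pivot(index='date', columns='category', values='diff')
    return (wide[x] - wide[y]).rename('spread')

# Hypothetical selection: x = 'a', y = 'b'
spread = spread_for(long_df, 'a', 'b')
print(spread)
```

With the question's data, the last value of `spread_for(ex_ts_long, x, y)` should agree with the heatmap cell for the same `(x, y)` pair, up to the two-decimal rounding applied to `ex_z`, since both are computed as the `x` series minus the `y` series.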