Advice on refactoring a Polars expression


I have a Polars expression, and I cannot use a context because my function has to return a Polars expression.

I've implemented an RSI indicator in Polars:

rsi_indicator = (100*pl.when(pl.col("close").pct_change() >= 0) \
            .then(pl.col("close").pct_change()) \
                .otherwise(0.0) \
                .rolling_mean(window_size=window) \
                    / (pl.when(pl.col("close").pct_change() >= 0) \
            .then(pl.col("close").pct_change()) \
                .otherwise(0.0).rolling_mean(window_size=window) + \
                    pl.when(pl.col("close").pct_change() < 0) \
            .then(pl.col("close").pct_change()) \
                .otherwise(0.0).abs().rolling_mean(window_size=window))).alias(f"rsi_{window}")

I would like to refactor this code by isolating some quantities, so that it is easier to maintain and read. For example, I'd like to define a variable

U = pl.when(pl.col("close").pct_change() >= 0) \
            .then(pl.col("close").pct_change()) \
                .otherwise(0.0) \
                .rolling_mean(window_size=window)

and its counterpart V for negative returns, so that I can simply return 100*U/(U+V), but it doesn't seem to work. Any advice?


Solution

  • First, let's use some real financial data, so that our results look reasonable. I'll also lowercase the column names so that your code works as-is.

    import polars as pl
    from yfinance import Ticker
    
    ticker_data = Ticker("AAPL").history(period="3mo")
    
    df = pl.from_pandas(ticker_data)
    df = df.rename({col_nm: col_nm.lower() for col_nm in df.columns})
    df.tail(10)
    
    shape: (10, 7)
    ┌────────────┬────────────┬────────────┬────────────┬──────────┬───────────┬──────────────┐
    │ open       ┆ high       ┆ low        ┆ close      ┆ volume   ┆ dividends ┆ stock splits │
    │ ---        ┆ ---        ┆ ---        ┆ ---        ┆ ---      ┆ ---       ┆ ---          │
    │ f64        ┆ f64        ┆ f64        ┆ f64        ┆ i64      ┆ f64       ┆ f64          │
    ╞════════════╪════════════╪════════════╪════════════╪══════════╪═══════════╪══════════════╡
    │ 227.899994 ┆ 228.0      ┆ 224.130005 ┆ 226.800003 ┆ 37245100 ┆ 0.0       ┆ 0.0          │
    │ 224.5      ┆ 225.690002 ┆ 221.330002 ┆ 221.690002 ┆ 39505400 ┆ 0.0       ┆ 0.0          │
    │ 224.300003 ┆ 225.979996 ┆ 223.25     ┆ 225.770004 ┆ 31855700 ┆ 0.0       ┆ 0.0          │
    │ 225.229996 ┆ 229.75     ┆ 224.830002 ┆ 229.539993 ┆ 33591100 ┆ 0.0       ┆ 0.0          │
    │ 227.779999 ┆ 229.5      ┆ 227.169998 ┆ 229.039993 ┆ 28183500 ┆ 0.0       ┆ 0.0          │
    │ 229.300003 ┆ 229.410004 ┆ 227.339996 ┆ 227.550003 ┆ 31759200 ┆ 0.0       ┆ 0.0          │
    │ 228.699997 ┆ 231.729996 ┆ 228.600006 ┆ 231.300003 ┆ 39882100 ┆ 0.0       ┆ 0.0          │
    │ 233.610001 ┆ 237.490005 ┆ 232.369995 ┆ 233.850006 ┆ 64751400 ┆ 0.0       ┆ 0.0          │
    │ 231.600006 ┆ 232.119995 ┆ 229.839996 ┆ 231.779999 ┆ 34065100 ┆ 0.0       ┆ 0.0          │
    │ 233.440002 ┆ 233.850006 ┆ 230.529999 ┆ 231.860001 ┆ 19355233 ┆ 0.0       ┆ 0.0          │
    └────────────┴────────────┴────────────┴────────────┴──────────┴───────────┴──────────────┘
    

    Next, I'll reformat your code, and construct U and V, per your question.

    window = 10
    rsi_indicator = (
        100
        * pl.when(pl.col("close").pct_change() >= 0)
        .then(pl.col("close").pct_change())
        .otherwise(0.0)
        .rolling_mean(window_size=window)
        / (
            pl.when(pl.col("close").pct_change() >= 0)
            .then(pl.col("close").pct_change())
            .otherwise(0.0)
            .rolling_mean(window_size=window)
            + pl.when(pl.col("close").pct_change() < 0)
            .then(pl.col("close").pct_change())
            .otherwise(0.0)
            .abs()
            .rolling_mean(window_size=window)
        )
    ).alias(f"rsi_{window}")
    
    U = (
        pl.when(pl.col("close").pct_change() >= 0)
        .then(pl.col("close").pct_change())
        .otherwise(0.0)
        .rolling_mean(window_size=window)
    )
    
    V = (
        pl.when(pl.col("close").pct_change() < 0)
        .then(pl.col("close").pct_change())
        .otherwise(0.0)
        .abs()
        .rolling_mean(window_size=window)
    )
    

    We can then express rsi_indicator in terms of U and V as follows:

    df.select(
        pl.col('close'),
        rsi_indicator,
        (100 * U / (U + V)).alias('rsi_UV')
    ).tail(10)
    
    shape: (10, 3)
    ┌────────────┬───────────┬───────────┐
    │ close      ┆ rsi_10    ┆ rsi_UV    │
    │ ---        ┆ ---       ┆ ---       │
    │ f64        ┆ f64       ┆ f64       │
    ╞════════════╪═══════════╪═══════════╡
    │ 226.800003 ┆ 46.898378 ┆ 46.898378 │
    │ 221.690002 ┆ 39.997919 ┆ 39.997919 │
    │ 225.770004 ┆ 47.459745 ┆ 47.459745 │
    │ 229.539993 ┆ 55.922485 ┆ 55.922485 │
    │ 229.039993 ┆ 53.166143 ┆ 53.166143 │
    │ 227.550003 ┆ 50.095898 ┆ 50.095898 │
    │ 231.300003 ┆ 47.53084  ┆ 47.53084  │
    │ 233.850006 ┆ 66.012769 ┆ 66.012769 │
    │ 231.779999 ┆ 60.06142  ┆ 60.06142  │
    │ 231.860001 ┆ 62.910393 ┆ 62.910393 │
    └────────────┴───────────┴───────────┘