Search code examples
python numpy performance optimization vectorization

Optimizing the rational quadratic kernel function


Given the following functions, what are some optimizations that can be done to speed up computations?

Yes, I tried using ChatGPT and Bard. I mention this because there is a caveat their solutions miss: np.nan if (index - i) < 0 else price_feed[index - i] has to hold. For each period I need to check the previous periods; if a previous period would fall before index 0 of the dataframe, its value is np.nan, since looking further into the past makes no sense: 1) it is not possible in live environments, and 2) you would be looking "into the future" (a negative index wraps around to the end of the array).

from pandas import DataFrame
import numpy as np
from numba import jit, float32, uint32, float64

@jit(
    float64[:](float64[:], uint32, float32, uint32),
    cache=True,
    # NOTE(review): the body contains no prange loop, so there is presumably
    # nothing explicit for parallel=True to parallelise -- confirm.
    parallel=True,
    nopython=True,
    # NOTE(review): this decorator is numba's CPU jit; GPU kernels normally
    # go through numba.cuda.jit -- confirm whether this option has any effect.
    target_backend="cuda",
    forceobj=False,
    nogil=True,
)
def rational_quadratic(
    price_feed: np.ndarray,
    lookback: int,
    relative_weight: float,
    start_at_bar: int,
) -> np.ndarray:
    """Smooth ``price_feed`` with a rational-quadratic weighted average.

    For each index the result is the weighted mean of the price at that
    index and the ``start_at_bar`` prices before it, where the price ``i``
    bars back gets the weight
    ``(1 + i**2 / (2 * relative_weight * lookback**2)) ** -relative_weight``.

    Args:
        price_feed: 1-D array of prices, oldest first.
        lookback: Length scale of the kernel; it enters squared in the
            weight denominator.
        relative_weight: Exponent and denominator factor of the kernel.
        start_at_bar: Number of previous bars included in every window.

    Returns:
        A float array with the same length as ``price_feed``. The first
        ``start_at_bar`` entries are NaN because their window would reach
        before the first price.
    """
    length_of_prices = len(price_feed)
    bars_calculated = start_at_bar + 1

    result = np.zeros(length_of_prices, dtype=float)
    lookback_squared = np.power(lookback, 2)
    denominator = lookback_squared * 2 * relative_weight

    # The weights depend only on the lag, never on the price index, so they
    # (and their sum) are computed once instead of once per price.
    weights = np.zeros(bars_calculated, dtype=float)
    cumulative_weight = 0.0
    for i in range(bars_calculated):
        w = np.power(
            1 + (np.power(i, 2) / denominator),
            -relative_weight,
        )
        weights[i] = w
        cumulative_weight += w

    for index in range(length_of_prices):
        if index < start_at_bar:
            # The window would need prices from before index 0. Those are
            # treated as NaN, which makes the whole weighted mean NaN.
            result[index] = np.nan
            continue

        current_weight = 0.0
        for i in range(bars_calculated):
            current_weight += price_feed[index - i] * weights[i]

        result[index] = current_weight / cumulative_weight

    return result


def rational_quadratic_wrapper(
    dataframe: DataFrame,
    lookback: int,
    relative_weight: float,
    start_at_bar: int,
    candle_type: str,
) -> DataFrame:
    """Return a copy of ``dataframe`` with rational-quadratic columns added.

    The input frame is not modified. Four columns are written to the copy:

    * ``no_filter``: ``rational_quadratic`` applied to ``candle_type``.
    * ``yhatdelt2``: ``rational_quadratic`` applied to ``no_filter``.
    * ``smooth``: ``no_filter - (no_filter - yhatdelt2)``.
    * ``zero_lag``: ``no_filter + (no_filter - yhatdelt2)``.

    Args:
        dataframe: Frame that contains the ``candle_type`` column.
        lookback: Passed through to ``rational_quadratic``.
        relative_weight: Passed through to ``rational_quadratic``.
        start_at_bar: Passed through to ``rational_quadratic``.
        candle_type: Name of the price column to smooth.

    Returns:
        The enriched copy of ``dataframe``.
    """
    enriched = dataframe.copy()
    kernel_args = (lookback, relative_weight, start_at_bar)

    # First pass smooths the raw prices, second pass smooths the first pass.
    first_pass = rational_quadratic(enriched[candle_type].values, *kernel_args)
    enriched["no_filter"] = first_pass
    enriched["yhatdelt2"] = rational_quadratic(first_pass, *kernel_args)

    base = enriched["no_filter"]
    second_pass = enriched["yhatdelt2"]
    enriched["smooth"] = base - (base - second_pass)
    enriched["zero_lag"] = base + (base - second_pass)

    return enriched


# Demo driver: build a small daily price frame and print the wrapper output.
import pandas as pd

fake_price_data = {'ohlc4': [
    4308.172, 4175.935, 4070.76, 4112.74, 4029.135,
    4308.172, 4175.935, 4070.76, 4112.74, 4029.135,
    4308.172, 4175.935, 4070.76, 4112.74, 4029.135,
    4308.172, 4175.935, 4070.76, 4112.74, 4029.135,
    4308.172, 4175.935, 4070.76, 4112.74, 4029.135,
    4029.135,
    4308.172, 4175.935, 4070.76, 4112.74, 4029.135,
    4308.172, 4175.935, 4070.76, 4112.74, 4029.135,
]}
# One daily timestamp per sample, so the index always matches the data length.
dates = pd.date_range(start='2017-08-17', periods=len(fake_price_data['ohlc4']), freq='D')
df = pd.DataFrame(fake_price_data, index=dates)

# lookback=8, relative_weight=1, start_at_bar=5, smoothing the "ohlc4" column.
results = rational_quadratic_wrapper(df, 8, 1, 5, "ohlc4")
print(results)

# Make sure your optimization matches the results dataframe output

# I tried using ChatGPT and Bard; all of their solutions either break the constraints or raise a shape error when doing cross multiplications with the relative weight. Wish you luck in the challenge!

Solution

  • You can use numba to optimize this operation:

    from numba import njit, prange
    
    
    @njit(parallel=True)
    def rational_quadratic_numba(price_feed, out, lookback, relative_weight, start_at_bar):
        """Write the rational-quadratic weighted average of ``price_feed`` into ``out``.

        For every index ``i`` from ``start_at_bar`` to the end of
        ``price_feed``, ``out[i]`` becomes the weighted mean of
        ``price_feed[i - j]`` for ``j = 0 .. start_at_bar``, where lag ``j``
        gets the weight
        ``(1 + j**2 / (2 * relative_weight * lookback**2)) ** -relative_weight``.

        Entries of ``out`` before ``start_at_bar`` are never written, so the
        caller decides what they hold (for example by pre-filling with NaN).
        Nothing is returned.
        """
        window = start_at_bar + 1
        lookback_squared = np.power(lookback, 2)
        denominator = lookback_squared * 2 * relative_weight

        # The weights depend only on the lag j, so compute them and their sum
        # once here instead of once per price inside the parallel loop.
        weights = np.empty(window, dtype=np.float64)
        cumulative_weight = 0.0
        for j in range(window):
            w = np.power(
                1 + (np.power(j, 2) / denominator),
                -relative_weight,
            )
            weights[j] = w
            cumulative_weight += w

        # Each iteration writes a different out[i], so the iterations are
        # independent of each other.
        for i in prange(start_at_bar, len(price_feed)):
            current_weight = 0.0

            for j in range(window):
                current_weight += price_feed[i - j] * weights[j]

            out[i] = current_weight / cumulative_weight
    

    The whole script (only no_filter_2 is computed with new function for simplicity):

    import numpy as np
    from pandas import DataFrame
    
    from numba import njit, prange
    
    
    @njit(parallel=True)
    def rational_quadratic_numba(price_feed, out, lookback, relative_weight, start_at_bar):
        """Write the rational-quadratic weighted average of ``price_feed`` into ``out``.

        For every index ``i`` from ``start_at_bar`` to the end of
        ``price_feed``, ``out[i]`` becomes the weighted mean of
        ``price_feed[i - j]`` for ``j = 0 .. start_at_bar``, where lag ``j``
        gets the weight
        ``(1 + j**2 / (2 * relative_weight * lookback**2)) ** -relative_weight``.

        Entries of ``out`` before ``start_at_bar`` are never written, so the
        caller decides what they hold (for example by pre-filling with NaN).
        Nothing is returned.
        """
        window = start_at_bar + 1
        lookback_squared = np.power(lookback, 2)
        denominator = lookback_squared * 2 * relative_weight

        # The weights depend only on the lag j, so compute them and their sum
        # once here instead of once per price inside the parallel loop.
        weights = np.empty(window, dtype=np.float64)
        cumulative_weight = 0.0
        for j in range(window):
            w = np.power(
                1 + (np.power(j, 2) / denominator),
                -relative_weight,
            )
            weights[j] = w
            cumulative_weight += w

        # Each iteration writes a different out[i], so the iterations are
        # independent of each other.
        for i in prange(start_at_bar, len(price_feed)):
            current_weight = 0.0

            for j in range(window):
                current_weight += price_feed[i - j] * weights[j]

            out[i] = current_weight / cumulative_weight
    
    
    def rational_quadratic(
        price_feed: np.ndarray,
        lookback: int,
        relative_weight: float,
        start_at_bar: int,
    ) -> np.ndarray:
        """Reference (pure Python) rational-quadratic weighted average.

        Every output element is the weighted mean of the price at that
        position and the ``start_at_bar`` prices before it. A price ``lag``
        bars back gets the weight
        ``(1 + lag**2 / (2 * relative_weight * lookback**2)) ** -relative_weight``.
        Prices that would lie before index 0 count as NaN, so the first
        ``start_at_bar`` outputs are NaN.

        Args:
            price_feed: 1-D array of prices, oldest first.
            lookback: Kernel length scale (squared in the denominator).
            relative_weight: Kernel exponent and denominator factor.
            start_at_bar: Number of previous bars in each window.

        Returns:
            Float array with one smoothed value per input price.
        """
        n_prices = len(price_feed)
        window = start_at_bar + 1
        scale = np.power(lookback, 2) * 2 * relative_weight

        smoothed = np.zeros(n_prices, dtype=float)
        for pos in range(n_prices):
            weighted_sum = 0.0
            weight_total = 0.0

            for lag in range(window):
                kernel = np.power(1 + (np.power(lag, 2) / scale), -relative_weight)
                if lag > pos:
                    # No data this far back; NaN poisons the whole window.
                    price = np.nan
                else:
                    price = price_feed[pos - lag]
                weighted_sum += price * kernel
                weight_total += kernel

            smoothed[pos] = weighted_sum / weight_total

        return smoothed
    
    
    def rational_quadratic_wrapper(
        dataframe: DataFrame,
        lookback: int,
        relative_weight: float,
        start_at_bar: int,
        candle_type: str,
    ) -> DataFrame:
        """Return a copy of ``dataframe`` with the smoothed columns added.

        The input frame is not modified. Columns written to the copy:

        * ``no_filter``: ``rational_quadratic`` applied to ``candle_type``.
        * ``no_filter_2``: the same quantity computed by
          ``rational_quadratic_numba``, for comparison with ``no_filter``.
        * ``yhatdelt2``: ``rational_quadratic`` applied to ``no_filter``.
        * ``smooth``: ``no_filter - (no_filter - yhatdelt2)``.
        * ``zero_lag``: ``no_filter + (no_filter - yhatdelt2)``.

        Args:
            dataframe: Frame that contains the ``candle_type`` column.
            lookback: Passed through to both kernels.
            relative_weight: Passed through to both kernels.
            start_at_bar: Passed through to both kernels.
            candle_type: Name of the price column to smooth.

        Returns:
            The enriched copy of ``dataframe``.
        """
        dataframe = dataframe.copy()

        ohlc4_values = dataframe[candle_type].values

        no_filter_values = rational_quadratic(
            ohlc4_values, lookback, relative_weight, start_at_bar
        )
        dataframe["no_filter"] = no_filter_values

        # Let the numba kernel fill a standalone, NaN-initialised buffer and
        # assign it afterwards. Writing through ``dataframe[col].values``
        # depends on that array being a writable view of the frame, which
        # pandas does not guarantee (it is read-only under Copy-on-Write).
        no_filter_2_values = np.full(len(ohlc4_values), np.nan)
        rational_quadratic_numba(
            ohlc4_values,
            no_filter_2_values,
            lookback,
            relative_weight,
            start_at_bar,
        )
        dataframe["no_filter_2"] = no_filter_2_values

        dataframe["yhatdelt2"] = rational_quadratic(
            no_filter_values, lookback, relative_weight, start_at_bar
        )
        dataframe["smooth"] = dataframe["no_filter"] - (
            dataframe["no_filter"] - dataframe["yhatdelt2"]
        )
        dataframe["zero_lag"] = dataframe["no_filter"] + (
            dataframe["no_filter"] - dataframe["yhatdelt2"]
        )

        return dataframe
    
    
    # Demo driver: build a small daily price frame and print the wrapper output.
    import pandas as pd

    fake_price_data = {
        "ohlc4": [
            4308.172,
            4175.935,
            4070.76,
            4112.74,
            4029.135,
            4308.172,
            4175.935,
            4070.76,
            4112.74,
            4029.135,
            4308.172,
            4175.935,
            4070.76,
            4112.74,
            4029.135,
            4308.172,
            4175.935,
            4070.76,
            4112.74,
            4029.135,
            4308.172,
            4175.935,
            4070.76,
            4112.74,
            4029.135,
            4029.135,
            4308.172,
            4175.935,
            4070.76,
            4112.74,
            4029.135,
            4308.172,
            4175.935,
            4070.76,
            4112.74,
            4029.135,
        ]
    }
    # One daily timestamp per sample, so the index always matches the data length.
    dates = pd.date_range(
        start="2017-08-17", periods=len(fake_price_data["ohlc4"]), freq="D"
    )
    df = pd.DataFrame(fake_price_data, index=dates)

    # lookback=8, relative_weight=1, start_at_bar=5, smoothing the "ohlc4" column.
    results = rational_quadratic_wrapper(df, 8, 1, 5, "ohlc4")
    print(results)
    

    Prints:

                   ohlc4    no_filter  no_filter_2    yhatdelt2       smooth     zero_lag
    2017-08-17  4308.172          NaN          NaN          NaN          NaN          NaN
    2017-08-18  4175.935          NaN          NaN          NaN          NaN          NaN
    2017-08-19  4070.760          NaN          NaN          NaN          NaN          NaN
    2017-08-20  4112.740          NaN          NaN          NaN          NaN          NaN
    2017-08-21  4029.135          NaN          NaN          NaN          NaN          NaN
    2017-08-22  4308.172  4164.845721  4164.845721          NaN          NaN          NaN
    2017-08-23  4175.935  4146.820891  4146.820891          NaN          NaN          NaN
    2017-08-24  4070.760  4129.994743  4129.994743          NaN          NaN          NaN
    2017-08-25  4112.740  4135.491538  4135.491538          NaN          NaN          NaN
    2017-08-26  4029.135  4119.589106  4119.589106          NaN          NaN          NaN
    2017-08-27  4308.172  4164.845721  4164.845721  4143.152441  4143.152441  4186.539001
    2017-08-28  4175.935  4146.820891  4146.820891  4140.761712  4140.761712  4152.880071
    2017-08-29  4070.760  4129.994743  4129.994743  4138.115818  4138.115818  4121.873668
    2017-08-30  4112.740  4135.491538  4135.491538  4138.839640  4138.839640  4132.143436
    2017-08-31  4029.135  4119.589106  4119.589106  4135.872389  4135.872389  4103.305823
    2017-09-01  4308.172  4164.845721  4164.845721  4143.152441  4143.152441  4186.539001
    2017-09-02  4175.935  4146.820891  4146.820891  4140.761712  4140.761712  4152.880071
    2017-09-03  4070.760  4129.994743  4129.994743  4138.115818  4138.115818  4121.873668
    2017-09-04  4112.740  4135.491538  4135.491538  4138.839640  4138.839640  4132.143436
    2017-09-05  4029.135  4119.589106  4119.589106  4135.872389  4135.872389  4103.305823
    2017-09-06  4308.172  4164.845721  4164.845721  4143.152441  4143.152441  4186.539001
    2017-09-07  4175.935  4146.820891  4146.820891  4140.761712  4140.761712  4152.880071
    2017-09-08  4070.760  4129.994743  4129.994743  4138.115818  4138.115818  4121.873668
    2017-09-09  4112.740  4135.491538  4135.491538  4138.839640  4138.839640  4132.143436
    2017-09-10  4029.135  4119.589106  4119.589106  4135.872389  4135.872389  4103.305823
    2017-09-11  4029.135  4115.210393  4115.210393  4134.323269  4134.323269  4096.097517
    2017-09-12  4308.172  4121.092758  4121.092758  4127.424441  4127.424441  4114.761075
    2017-09-13  4175.935  4123.912211  4123.912211  4123.931166  4123.931166  4123.893256
    2017-09-14  4070.760  4123.022699  4123.022699  4122.861083  4122.861083  4123.184315
    2017-09-15  4112.740  4123.049833  4123.049833  4121.113987  4121.113987  4124.985679
    2017-09-16  4029.135  4119.589106  4119.589106  4121.096805  4121.096805  4118.081408
    2017-09-17  4308.172  4164.845721  4164.845721  4129.714298  4129.714298  4199.977144
    2017-09-18  4175.935  4146.820891  4146.820891  4134.182404  4134.182404  4159.459379
    2017-09-19  4070.760  4129.994743  4129.994743  4135.111035  4135.111035  4124.878451
    2017-09-20  4112.740  4135.491538  4135.491538  4136.988124  4136.988124  4133.994952
    2017-09-21  4029.135  4119.589106  4119.589106  4135.872389  4135.872389  4103.305823
    

    Benchmark:

    from timeit import timeit

    import pandas as pd

    # 36,000 random prices in [50, 150) on a daily index.
    dates = pd.date_range(start="2017-08-17", periods=36_000, freq="D")
    df = pd.DataFrame({"ohlc4": 50 + np.random.random(len(dates)) * 100}, index=dates)

    # Baseline: the pure-Python double loop.
    t1 = timeit(
        "rational_quadratic(df['ohlc4'].values, 8, 1, 5)", number=1, globals=globals()
    )

    # numba version: it fills a standalone NaN buffer that is then stored as a
    # column, instead of writing through ``df[col].values`` (not guaranteed to
    # be a writable view of the frame).
    # NOTE: assumes rational_quadratic_numba has already been called once (as
    # the script above does); otherwise JIT compilation is part of this timing.
    t2 = timeit(
        "out=np.full(len(df), np.nan);rational_quadratic_numba(df['ohlc4'].values, out, 8, 1, 5);df['no_filter_2']=out",
        number=1,
        globals=globals(),
    )

    print(f"Time normal = {t1}")
    print(f"Time numba =  {t2}")
    

    On my computer (AMD 5700X) this prints:

    Time normal = 0.3848473190009827
    Time numba =  0.000804967996373307