Tags: python, pandas, bisect

Calculate tax liabilities based on a marginal tax rate schedule


The question "income tax calculation python" asks how to calculate taxes given a marginal tax rate schedule, and its answer provides a function that works (below).

However, it works only for a single value of income. How would I adapt it to work for a list/numpy array/pandas Series of income values? That is, how do I vectorize this code?

from bisect import bisect

rates = [0, 10, 20, 30]   # 0%  10%  20%  30%

brackets = [10000,        # first 10,000
            30000,        # next  20,000
            70000]        # next  40,000

base_tax = [0,            # 10,000 * 0%
            2000,         # 20,000 * 10%
            10000]        # 40,000 * 20% + 2,000

def tax(income):
    i = bisect(brackets, income)
    if not i:
        return 0
    rate = rates[i]
    bracket = brackets[i-1]
    income_in_bracket = income - bracket
    tax_in_bracket = income_in_bracket * rate / 100
    total_tax = base_tax[i-1] + tax_in_bracket
    return total_tax

Solution

  • This approach vectorizes the marginal tax calculation using only NumPy.

    import numpy as np

    def tax(incomes, bands, rates):
        # Accept lists and pandas Series as well as NumPy arrays
        incomes, bands, rates = np.asarray(incomes), np.asarray(bands), np.asarray(rates)
        # Broadcast incomes so that we can compute an amount per income, per band
        incomes_ = np.broadcast_to(incomes, (bands.shape[0] - 1, incomes.shape[0]))
        # Find amounts in bands for each income
        amounts_in_bands = np.clip(incomes_.transpose(),
                                   bands[:-1], bands[1:]) - bands[:-1]
        # Calculate tax per band
        taxes = rates * amounts_in_bands
        # Sum tax bands per income
        return taxes.sum(axis=1)
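
To see what the clip step in the function above does, it can help to trace a single income by hand. The sketch below is only an illustration: it uses the same band edges and rates as the usage example in this answer, and the variable names income and tax_due are chosen here for the walkthrough.

```python
import numpy as np

# Band edges and rates as used in this answer's usage example
bands = np.array([0, 12500, 50000, 150000, np.inf])
rates = np.array([0, 0.2, 0.4, 0.45])

income = 56000  # illustrative single income

# Clip the income into each band's [lower, upper] range, then subtract
# the lower edge: what remains is the part of the income inside that band.
amounts_in_bands = np.clip(income, bands[:-1], bands[1:]) - bands[:-1]
# amounts per band: 12500, 37500, 6000, 0

# Tax each slice at its own rate and add the slices up.
tax_due = (rates * amounts_in_bands).sum()
# 0 + 7500 + 2400 + 0 = 9900, so 56000 - 9900 = 46100 remains after tax
print(amounts_in_bands, tax_due)
```

The vectorized function does the same thing for every income at once, with one row per income and one column per band.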
    

    For usage, bands should list every band edge, including the upper limit (np.inf here), so it has one more entry than rates. In my view this makes the schedule more explicit.

    import pandas as pd

    incomes = np.array([0, 7000, 14000, 28000, 56000, 77000, 210000])
    bands = np.array([0, 12500, 50000, 150000, np.inf])
    rates = np.array([0, 0.2, 0.4, 0.45])
    
    df = pd.DataFrame()
    df['pre_tax'] = incomes
    df['post_tax'] = incomes - tax(incomes, bands, rates)
    print(df)
    

    Output:

       pre_tax  post_tax
    0        0       0.0
    1     7000    7000.0
    2    14000   13700.0
    3    28000   24900.0
    4    56000   46100.0
    5    77000   58700.0
    6   210000  135500.0
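
If you would rather stay close to the bisect-based function from the question, np.searchsorted with side="right" performs the same lookup as bisect for a whole array at once. The following is a minimal sketch under that assumption, not part of the original answer. It reuses the question's schedule; the names lower_edges and tax_searchsorted are illustrative, and base_tax is padded with a leading 0 so that incomes below the first bracket need no special case.

```python
import numpy as np

# Schedule from the question, with a leading 0 added to the bracket
# edges and to base_tax so that index 0 means "below the first bracket".
lower_edges = np.array([0, 10000, 30000, 70000])
base_tax = np.array([0, 0, 2000, 10000])
rates = np.array([0, 10, 20, 30])  # percent

def tax_searchsorted(incomes):
    # np.asarray accepts a list, a NumPy array, or a pandas Series
    incomes = np.asarray(incomes, dtype=float)
    # side="right" matches bisect(): an income equal to a bracket
    # edge is placed in the bracket above that edge
    i = np.searchsorted(lower_edges[1:], incomes, side="right")
    # Tax on all lower brackets plus tax on the part inside bracket i
    return base_tax[i] + (incomes - lower_edges[i]) * rates[i] / 100

print(tax_searchsorted([5000, 20000, 50000, 100000]))
# values: 0, 1000, 6000, 19000
```

This version returns a NumPy array; to keep a pandas index, the result can be assigned back to a DataFrame column as in the usage example above.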