Simple tax calculation in Scala


Suppose I am writing a toy tax calculator with two functions:

// calculate the tax amount for a particular income given tax brackets
def tax(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = ???

// calculate the minimum income for a particular tax rate given tax brackets
def income(taxRate: BigDecimal, brackets: Seq[Bracket]): BigDecimal = ???

I define a tax bracket like this:

case class Bracket(maxIncomeOpt: Option[BigDecimal], rate: BigDecimal)

Bracket(Some(BigDecimal(10)), BigDecimal(10)) means a tax bracket of 10% for income up to 10.
Bracket(Some(BigDecimal(20)), BigDecimal(20)) means a tax bracket of 20% for income up to 20.
Bracket(None, BigDecimal(30)) means a tax bracket of 30% for any income above that (an open-ended bracket).
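Read this way, each rate applies only to the slice of income between the previous bracket's limit and its own. A worked example for an arbitrarily chosen income of 25 under these three brackets:

$$\mathrm{tax}(25) = 10 \cdot 10\% + (20 - 10) \cdot 20\% + (25 - 20) \cdot 30\% = 1 + 2 + 1.5 = 4.5$$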

I have written the function tax like this:

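def tax(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = {
  // Accumulator: (tax so far, income not yet taxed, upper limit of the previous bracket)
  val (result, _, _) = brackets.foldLeft((BigDecimal(0), income, BigDecimal(0))) {
    case ((result, rest, lower), curr) =>
      // Income in this bracket: at most (limit - lower), and at most what is left
      val taxable = curr.maxIncomeOpt.fold(rest)(limit => (limit - lower).min(rest))
      val upper   = curr.maxIncomeOpt.getOrElse(lower)
      (result + taxable * curr.rate / BigDecimal(100), rest - taxable, upper)
  }
  result
}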

The function tax seems to work, but I think Seq[Bracket] is not the best way to define tax brackets: the brackets form a sorted sequence of disjoint, "back-to-back" intervals with an open interval at the end. How would you define tax brackets?


Solution

  • Consider a solution that uses algebraic data types to define the brackets and Double.PositiveInfinity to simulate the open interval

    sealed abstract class TaxBracket(val from: Double, val to: Double, val rate: Double) {
      // Tax on the slice of `income` between `from` and `to`; with to = +Infinity,
      // math.min(income, to) is just income, which models the open top band
      def tax(income: Double): Double =
        math.max(0.0, math.min(income, to) - from) * rate
    }
    // Back-to-back bands: each band starts exactly where the previous one ends
    case object A extends TaxBracket(0, 12500, 0.0)
    case object B extends TaxBracket(12500, 50000, 0.2)
    case object C extends TaxBracket(50000, 150000, 0.4)
    case object D extends TaxBracket(150000, Double.PositiveInfinity, 0.45)

    Tax calculation then simplifies to

    def tax(income: Double, bands: List[TaxBracket]): Double =
      bands.map(_.tax(income)).sum

    For example, using the UK tax bands defined above, we get

    tax(60000, List(A, B, C, D)) // res0: Double = 11500.0

    which can be verified here.

    To find the first band in which a given effective tax rate (ETR) is reached, and return that band's lower bound, try

    def income(etr: Double, bands: List[TaxBracket]): Option[Double] =
      bands
        .find(b => b.to.isPosInfinity || tax(b.to, bands) / b.to >= etr)
        .map(_.from)

    income(0.4, List(A, B, C, D)) // res1: Option[Double] = Some(150000.0)

    Note that this is the lower bound of the band, not the exact income at which the ETR first equals 0.4.
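The question also asks how to represent the brackets so that "sorted, back-to-back, open at the end" holds by design, and the income function above stops at a band's lower bound. Below is a minimal sketch, not the answer's own method, written as a Scala 3 script with BigDecimal to match the question. `TaxSchedule`, `intervals` and `minIncome` are hypothetical names. Bounded bands are given as (upper limit, rate) in ascending order, and the open top band is only a rate. Because tax is linear inside a band, the exact income at which the effective rate reaches a target can be solved for directly. With the UK bands this is 400000, since 47500 + 0.45 · (x − 150000) = 0.4 · x gives x = 400000.

```scala
// Hypothetical representation: each bounded band is (upper limit, rate), listed in
// ascending order, and the open-ended top band is only a rate. Bands therefore
// start at 0, are back-to-back, and end in an open interval by construction.
final case class TaxSchedule(bands: List[(BigDecimal, BigDecimal)], topRate: BigDecimal) {
  private val limits = bands.map(_._1)
  require(limits.forall(_ > 0), "band limits must be positive")
  require(limits.zip(limits.drop(1)).forall { case (a, b) => a < b },
    "band limits must be strictly increasing")

  // (from, to, rate) for every band; `to` is None for the open top band
  val intervals: List[(BigDecimal, Option[BigDecimal], BigDecimal)] = {
    val froms: List[BigDecimal]       = BigDecimal(0) :: limits
    val tos: List[Option[BigDecimal]] = limits.map(l => Option(l)) :+ None
    val rates: List[BigDecimal]       = bands.map(_._2) :+ topRate
    froms.zip(tos).zip(rates).map { case ((f, t), r) => (f, t, r) }
  }

  // Tax on `income`: each band taxes the slice of income between its `from` and `to`
  def tax(income: BigDecimal): BigDecimal =
    intervals.map { case (from, to, rate) =>
      val top = to.fold(income)(_.min(income))
      (top - from).max(BigDecimal(0)) * rate
    }.sum

  // Smallest income whose effective rate tax(x) / x reaches `etr`, or None if no band
  // gets there. The effective rate only climbs towards `etr` in a band whose marginal
  // rate exceeds it, and inside a band tax(x) = tax(from) + rate * (x - from) is linear.
  // If the first band's rate already exceeds `etr`, this returns 0.
  def minIncome(etr: BigDecimal): Option[BigDecimal] =
    intervals.iterator.map { case (from, to, rate) =>
      if (rate > etr) {
        // Solve tax(from) + rate * (x - from) = etr * x for x
        val x = (rate * from - tax(from)) / (rate - etr)
        if (to.forall(x <= _)) Some(x) else None
      } else None
    }.collectFirst { case Some(x) => x }
}

// UK bands as used in the answer above
val uk = TaxSchedule(
  List(
    BigDecimal(12500)  -> BigDecimal("0.0"),
    BigDecimal(50000)  -> BigDecimal("0.2"),
    BigDecimal(150000) -> BigDecimal("0.4")
  ),
  topRate = BigDecimal("0.45")
)
```

Here `uk.tax(BigDecimal(60000))` equals 11500 and `uk.minIncome(BigDecimal("0.4"))` equals Some(400000). Keeping the top band as a bare `topRate` avoids representing infinity at all, so there is no open-interval case to check at run time. The trade-off is that every schedule must be built through this one constructor.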