Suppose I am writing a toy tax calculator with two functions:
// calculate the tax amount for a particular income given tax brackets
def tax(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = ???
// calculate the min. income for a particular tax rate given tax brackets
def income(taxRate: BigDecimal, brackets: Seq[Bracket]): BigDecimal = ???
I define a tax bracket like this:
case class Bracket(maxIncomeOpt: Option[BigDecimal], rate: BigDecimal)
Bracket(Some(BigDecimal(10)), BigDecimal(10)) means a tax bracket of 10% for income up to 10,
Bracket(Some(BigDecimal(20)), BigDecimal(20)) means a tax bracket of 20% for income up to 20, and
Bracket(None, BigDecimal(30)) means a tax bracket of 30% for any remaining income.
Now I am writing function tax like this:
def tax(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = {
  // the fold carries (tax so far, income not yet taxed); the tax is the FIRST element
  val (result, _) = brackets.foldLeft((BigDecimal(0), income)) { case ((result, rest), curr) =>
    val taxable = curr.maxIncomeOpt.fold(rest)(_.min(rest))
    (result + taxable * curr.rate / 100, rest - taxable)
  }
  result
}
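Here is a self-contained version of tax that can be run as-is. It returns the first element of the fold's pair (the accumulated tax) and applies it to the three example brackets. Because each bracket taxes at most maxIncomeOpt of the income that is still untaxed, maxIncomeOpt effectively acts as the width of each band rather than a cumulative upper bound. The object name FoldTax is only for illustration:

```scala
object FoldTax {
  final case class Bracket(maxIncomeOpt: Option[BigDecimal], rate: BigDecimal)

  // Accumulator is (taxSoFar, incomeNotYetTaxed). Each bracket taxes at most
  // maxIncomeOpt of the remaining income; None means "all of the remainder".
  def tax(income: BigDecimal, brackets: Seq[Bracket]): BigDecimal = {
    val (result, _) = brackets.foldLeft((BigDecimal(0), income)) { case ((result, rest), curr) =>
      val taxable = curr.maxIncomeOpt.fold(rest)(_.min(rest))
      (result + taxable * curr.rate / 100, rest - taxable)
    }
    result
  }

  val brackets: Seq[Bracket] = Seq(
    Bracket(Some(BigDecimal(10)), BigDecimal(10)),
    Bracket(Some(BigDecimal(20)), BigDecimal(20)),
    Bracket(None, BigDecimal(30))
  )

  def main(args: Array[String]): Unit =
    println(tax(BigDecimal(35), brackets)) // 10*10% + 20*20% + 5*30% = 6.5
}
```

With an income of 35, the second bracket consumes the next 20 units (not "up to 20"), so only 5 units reach the 30% bracket.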
Function tax seems to work, but I think Seq[Bracket] is not the best way to define tax brackets. The tax brackets are a sorted sequence of disjoint "back-to-back" intervals with an open interval at the end. How would you define tax brackets?
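One way to make the "sorted, back-to-back, open at the end" shape part of the type, sketched here as one possible design rather than a definitive one, is a recursive ADT: each Upto band points to the rest of the schedule, and the chain must end in exactly one open-ended Above band. The names BracketsAdt, Upto, Above and isSorted are hypothetical. Here limit is a cumulative upper bound (matching the description "for income up to 20"), and since the types do not enforce increasing limits, that is checked separately:

```scala
import scala.annotation.tailrec

object BracketsAdt {
  // A schedule is a chain of bounded bands that always ends in one open-ended band.
  sealed trait Brackets
  // Rate (percent) for income up to and including `limit`; `next` covers income above it.
  final case class Upto(limit: BigDecimal, rate: BigDecimal, next: Brackets) extends Brackets
  // Rate (percent) for all remaining income: the open interval at the end.
  final case class Above(rate: BigDecimal) extends Brackets

  private val Hundred = BigDecimal(100)

  // Limits must strictly increase along the chain.
  def isSorted(b: Brackets, lower: BigDecimal = BigDecimal(0)): Boolean = b match {
    case Above(_)             => true
    case Upto(limit, _, next) => limit > lower && isSorted(next, limit)
  }

  def tax(income: BigDecimal, brackets: Brackets): BigDecimal = {
    @tailrec
    def loop(b: Brackets, lower: BigDecimal, acc: BigDecimal): BigDecimal = b match {
      case Above(rate) =>
        acc + (income - lower).max(BigDecimal(0)) * rate / Hundred
      case Upto(limit, rate, next) =>
        // the part of income between the previous limit and this one
        val taxable = (income.min(limit) - lower).max(BigDecimal(0))
        loop(next, limit, acc + taxable * rate / Hundred)
    }
    loop(brackets, BigDecimal(0), BigDecimal(0))
  }

  // 10% up to 10, 20% from 10 up to 20, 30% above 20
  val example: Brackets =
    Upto(BigDecimal(10), BigDecimal(10), Upto(BigDecimal(20), BigDecimal(20), Above(BigDecimal(30))))

  def main(args: Array[String]): Unit = {
    println(isSorted(example))            // true
    println(tax(BigDecimal(35), example)) // 10*10% + 10*20% + 15*30% = 7.5
  }
}
```

The open interval can no longer be missing or appear in the middle, because Above has no successor and every Upto needs one.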
Consider a solution using algebraic data types to define the brackets and Double.PositiveInfinity to simulate the open interval:
abstract class TaxBracket(val from: Double, val to: Double, val rate: Double) {
  // Bounds are inclusive whole units, so a band covers to - (from - 1) units of income.
  def tax(income: Double): Double =
    if (income >= from)
      if (to.isPosInfinity) (income - (from - 1)) * rate // open-ended top band
      else if (income > to) (to - (from - 1)) * rate     // band fully used
      else (income - (from - 1)) * rate                  // income ends inside this band
    else
      0.0                                                // income below this band
}
case object A extends TaxBracket(0, 12500, 0.0)
case object B extends TaxBracket(12501, 50000, 0.2)
case object C extends TaxBracket(50001, 150000, 0.4)
case object D extends TaxBracket(150001, Double.PositiveInfinity, 0.45)
Now tax calculation simplifies to
def tax(income: Double, bands: List[TaxBracket]): Double =
  bands.map(_.tax(income)).sum
For example, using the UK tax bands defined above, we get
tax(60000, List(A, B, C, D)) // res0: Double = 11500.0
which can be verified here.
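Because the bands are written out by hand, nothing in the types stops a gap, an overlap, or a missing open-ended band. A small runtime check, assuming the convention above that each band starts one unit above the previous band's upper bound, could look like this (BandCheck and isWellFormed are hypothetical names, and the bracket definitions are repeated so the block compiles on its own):

```scala
object BandCheck {
  // Same shape as the TaxBracket hierarchy above; only the bounds matter here.
  abstract class TaxBracket(val from: Double, val to: Double, val rate: Double)
  case object A extends TaxBracket(0, 12500, 0.0)
  case object B extends TaxBracket(12501, 50000, 0.2)
  case object C extends TaxBracket(50001, 150000, 0.4)
  case object D extends TaxBracket(150001, Double.PositiveInfinity, 0.45)

  // True when: the list is non-empty, only the last band is open-ended,
  // every finite band has from <= to, and each band starts at previous.to + 1.
  def isWellFormed(bands: List[TaxBracket]): Boolean =
    bands.nonEmpty &&
      bands.last.to.isPosInfinity &&
      bands.init.forall(b => !b.to.isPosInfinity && b.from <= b.to) &&
      bands.zip(bands.tail).forall { case (prev, next) => next.from == prev.to + 1 }

  def main(args: Array[String]): Unit = {
    println(isWellFormed(List(A, B, C, D))) // true
    println(isWellFormed(List(A, C, D)))    // false: gap between 12500 and 50001
    println(isWellFormed(List(A, B, C)))    // false: no open-ended top band
  }
}
```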
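To approximate the minimum income for a given effective tax rate, try the following; it returns the lower bound of the first band whose upper bound already reaches that rate: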
def income(etr: Double, bands: List[TaxBracket]): Option[Double] = {
  // first band whose upper bound yields an effective rate >= etr (the open band always qualifies)
  bands.map(b => (b.from, b.to)).find { case (from, to) =>
    if (to.isPosInfinity) true
    else (tax(to, bands) / to) >= etr
  }.map { case (lowerBound, _) => lowerBound }
}
income(0.4, List(A, B, C, D)) // res1: Option[Double] = Some(150001.0)
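Note that 150001 is where the band containing the target starts, not where a 40% effective rate is reached: at 150001 the effective rate is still only about 31.7%, and with a 45% top rate, 40% is first reached at an income of 400000. If the exact threshold is wanted, one option, assuming 0 < etr and rates that never decrease from band to band, is to solve tax(x) = etr * x linearly inside that band. minIncome is a hypothetical name, and the bracket classes are repeated (using the inclusive-bound width to - (from - 1) consistently) so the block compiles on its own:

```scala
object MinIncome {
  abstract class TaxBracket(val from: Double, val to: Double, val rate: Double) {
    // Inclusive whole-unit bounds: the band covers to - (from - 1) units of income.
    def tax(income: Double): Double =
      if (income < from) 0.0
      else if (to.isPosInfinity || income <= to) (income - (from - 1)) * rate
      else (to - (from - 1)) * rate
  }
  case object A extends TaxBracket(0, 12500, 0.0)
  case object B extends TaxBracket(12501, 50000, 0.2)
  case object C extends TaxBracket(50001, 150000, 0.4)
  case object D extends TaxBracket(150001, Double.PositiveInfinity, 0.45)

  def tax(income: Double, bands: List[TaxBracket]): Double = bands.map(_.tax(income)).sum

  // Inside band b: tax(x) = T + r * (x - base), with base = b.from - 1 and T = tax(base).
  // Setting tax(x) = etr * x gives x = (r * base - T) / (r - etr).
  def minIncome(etr: Double, bands: List[TaxBracket]): Option[Double] =
    bands.find(b => b.to.isPosInfinity || tax(b.to, bands) / b.to >= etr).flatMap { b =>
      val base = b.from - 1
      val taxBelow = tax(base, bands)
      // r <= etr can only happen in the open band: the effective rate approaches r but never reaches etr
      if (b.rate <= etr) None
      // clamped so the result never falls below the band's start
      else Some(math.max(b.from, (b.rate * base - taxBelow) / (b.rate - etr)))
    }

  def main(args: Array[String]): Unit = {
    println(minIncome(0.4, List(A, B, C, D))) // about 400000, up to floating-point rounding
    println(minIncome(0.5, List(A, B, C, D))) // None: the top marginal rate is 45%
  }
}
```

Check: tax(400000) = 47500 + 0.45 * (400000 - 150000) = 160000, which is 40% of 400000.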