
Extend case class with typeclass


I have this code:

  case class Salary(employee: String, amount: Double)

  trait XN[M] {
    def x2(m: M): M
    def x3(m: M): M
    def x4(m: M): M
  }

    

I want to extend Salary with the XN trait so that the following test passes:

  test("Salary is extended with xN") {
    val bobSalary = Salary("Bob", 100.0)

    bobSalary.x2 shouldBe Salary("Bob", 200.0)
    bobSalary.x3 shouldBe Salary("Bob", 300.0)
    bobSalary.x4 shouldBe Salary("Bob", 400.0)
  }

My attempts:

#1

  implicit val SalaryXN: XN[Salary] = new XN[Salary] {
    override def x2(m: Salary): Salary = m.copy(amount = m.amount * 2)

    override def x3(m: Salary): Salary = m.copy(amount = m.amount * 3)

    override def x4(m: Salary): Salary = m.copy(amount = m.amount * 4)
  }
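Attempt #1 on its own already provides a working instance; what is missing is the method syntax. The methods can be reached through the instance (here via `implicitly`), but `bobSalary.x2` does not compile because nothing adds an `x2` member to `Salary`. A minimal sketch, assuming Scala 2.13; `Attempt1` and `doubled` are hypothetical names used only for illustration:

```scala
case class Salary(employee: String, amount: Double)

trait XN[M] {
  def x2(m: M): M
  def x3(m: M): M
  def x4(m: M): M
}

// Hypothetical wrapper holding attempt #1 as written in the question.
object Attempt1 {
  implicit val SalaryXN: XN[Salary] = new XN[Salary] {
    override def x2(m: Salary): Salary = m.copy(amount = m.amount * 2)
    override def x3(m: Salary): Salary = m.copy(amount = m.amount * 3)
    override def x4(m: Salary): Salary = m.copy(amount = m.amount * 4)
  }

  val bobSalary = Salary("Bob", 100.0)

  // Compiles: summon the instance and pass the value explicitly.
  val doubled: Salary = implicitly[XN[Salary]].x2(bobSalary)

  // Does not compile: Salary itself has no x2 member.
  // bobSalary.x2
}
```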

#2

  object Salary extends XN[Salary] {
    override def x2(m: Salary): Salary = new Salary(employee = m.employee, amount = m.amount * 2)

    override def x3(m: Salary): Salary = new Salary(employee = m.employee, amount = m.amount * 3)

    override def x4(m: Salary): Salary = new Salary(employee = m.employee, amount = m.amount * 4)
  }
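Attempt #2 makes the companion object itself an `XN[Salary]`. The methods then live on the object `Salary`, not on its instances, so `Salary.x3(bobSalary)` compiles while `bobSalary.x3` does not. The companion is also not declared `implicit`, so it would not be found as an `XN[Salary]` instance either. A minimal sketch, assuming Scala 2.13; `Attempt2` and `tripled` are hypothetical names:

```scala
case class Salary(employee: String, amount: Double)

trait XN[M] {
  def x2(m: M): M
  def x3(m: M): M
  def x4(m: M): M
}

// Attempt #2: the companion object *is* an XN[Salary] (but not an implicit one).
object Salary extends XN[Salary] {
  override def x2(m: Salary): Salary = new Salary(employee = m.employee, amount = m.amount * 2)
  override def x3(m: Salary): Salary = new Salary(employee = m.employee, amount = m.amount * 3)
  override def x4(m: Salary): Salary = new Salary(employee = m.employee, amount = m.amount * 4)
}

// Hypothetical wrapper showing how attempt #2 can actually be called.
object Attempt2 {
  val bobSalary = Salary("Bob", 100.0)

  // Compiles: call the method on the companion and pass the instance.
  val tripled: Salary = Salary.x3(bobSalary)

  // Does not compile: instances of Salary still have no x3 member.
  // bobSalary.x3
}
```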

How can I do that?



Solution

  • Since XN appears to be a typeclass, it is better to use that pattern properly (an instance plus extension-method syntax) than to rely on an implicit conversion, which is discouraged.

    trait XN[M] {
      def x2(m: M): M
      def x3(m: M): M
      def x4(m: M): M
    }
    
    object XN {
      object syntax {
        implicit class XNOp[M](private val m: M) extends AnyVal {
          @inline final def x2(implicit ev: XN[M]): M = ev.x2(m)
          @inline final def x3(implicit ev: XN[M]): M = ev.x3(m)
          @inline final def x4(implicit ev: XN[M]): M = ev.x4(m)
        }
      }
    }
    
    final case class Salary(employee: String, amount: Double)
    object Salary {
      implicit final val SalaryXN: XN[Salary] =
        new XN[Salary] {
          override def x2(s: Salary): Salary = s.copy(amount = s.amount * 2)
          override def x3(s: Salary): Salary = s.copy(amount = s.amount * 3)
          override def x4(s: Salary): Salary = s.copy(amount = s.amount * 4)
        }
    }
    

    It can then be used like this:

    import XN.syntax._
    
    val bobSalary = Salary("Bob", 100.0)
    
    bobSalary.x2
    // res: Salary = Salary("Bob", 200.0)
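To check the answer against the question's expectations without a ScalaTest dependency, the sketch below assembles the answer's definitions into one file and replaces the `shouldBe` checks with plain values that can be tested with `assert`. It also adds a generic helper with a context bound and a second, hypothetical `XN[Int]` instance to show that the syntax is not tied to `Salary`. `SalaryCheck`, `multiples`, and `IntXN` are illustrative names, and Scala 2.13 is assumed:

```scala
trait XN[M] {
  def x2(m: M): M
  def x3(m: M): M
  def x4(m: M): M
}

object XN {
  object syntax {
    // Adds x2/x3/x4 to any M for which an XN[M] instance is available.
    implicit class XNOp[M](private val m: M) extends AnyVal {
      @inline final def x2(implicit ev: XN[M]): M = ev.x2(m)
      @inline final def x3(implicit ev: XN[M]): M = ev.x3(m)
      @inline final def x4(implicit ev: XN[M]): M = ev.x4(m)
    }
  }
}

final case class Salary(employee: String, amount: Double)
object Salary {
  // Lives in the companion object, so it is found through implicit scope without an import.
  implicit final val SalaryXN: XN[Salary] =
    new XN[Salary] {
      override def x2(s: Salary): Salary = s.copy(amount = s.amount * 2)
      override def x3(s: Salary): Salary = s.copy(amount = s.amount * 3)
      override def x4(s: Salary): Salary = s.copy(amount = s.amount * 4)
    }
}

object SalaryCheck {
  import XN.syntax._

  // Hypothetical extra instance, defined before use so it is initialized in time.
  implicit val IntXN: XN[Int] = new XN[Int] {
    override def x2(m: Int): Int = m * 2
    override def x3(m: Int): Int = m * 3
    override def x4(m: Int): Int = m * 4
  }

  // Works for any M with an XN instance; the context bound supplies the evidence.
  def multiples[M: XN](m: M): List[M] = List(m.x2, m.x3, m.x4)

  val bobSalary = Salary("Bob", 100.0)
  val doubled: Salary = bobSalary.x2 // the syntax the question's test uses
  val salaryResults: List[Salary] = multiples(bobSalary)
  val intResults: List[Int] = multiples(7)
}
```

Only `XN.syntax._` has to be imported at the call site; `SalaryXN` is picked up automatically because it is defined in the `Salary` companion object.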
    
