scala shapeless idris

Generic Adder from Idris to Scala?


Type-Driven Development with Idris presents the following generic adder approach:

AdderType : (numArgs : Nat) -> Type
AdderType Z     = Int
AdderType (S k) = (next : Int) -> AdderType k

adder : (n : Nat) -> (acc : Int) -> AdderType n
adder Z acc     = acc
adder (S k) acc = \x => (adder k (x+acc))

Example:

-- expects 3 Ints to add, with a starting value of 0
*Work> :t (adder 3 0) 
adder 3 0 : Int -> Int -> Int -> Int

-- 0 (initial) + 3 + 3 + 3 == 9
*Work> (adder 3 0) 3 3 3
9 : Int
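Unfolded by hand, adder 3 0 is just a curried function of three Ints. The plain-Scala sketch below mirrors the two Idris equations one lambda per argument. It is only an illustration: the names UnrolledAdder and adder3 are hypothetical, and the arity is fixed at 3 rather than computed from a Nat.

```scala
object UnrolledAdder {
  // Hand-unrolled analogue of Idris's `adder 3 acc`.
  // Each lambda mirrors `adder (S k) acc = \x => adder k (x + acc)`:
  // it takes one Int and adds it to the running accumulator.
  // The innermost body mirrors `adder Z acc = acc`: no arguments left,
  // so the accumulated sum is returned.
  def adder3(acc: Int): Int => Int => Int => Int =
    x1 => x2 => x3 => x3 + (x2 + (x1 + acc))
}
```

With this, UnrolledAdder.adder3(0)(3)(3)(3) evaluates to 9, matching the Idris REPL session. The interesting part of the question is getting the compiler to produce this shape for any Nat.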

I'm guessing that shapeless can handle the above generic adder function.

How can it be written in Scala with or without shapeless?


Solution

  • Update: I'll leave my original implementation below, but here's one that's a little more direct:

    import shapeless._
    
    trait AdderType[N <: Nat] extends DepFn1[Int]
    
    object AdderType {
      type Aux[N <: Nat, Out0] = AdderType[N] { type Out = Out0 }
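      // Summon the instance for N and start it off with the initial accumulator.
      def apply[N <: Nat](base: Int)(implicit at: AdderType[N]): at.Out = at(base)
    
      // Base case, like AdderType Z = Int: no arguments left, so return the accumulator.
      implicit val adderTypeZero: Aux[Nat._0, Int] = new AdderType[Nat._0] {
        type Out = Int
        def apply(x: Int): Int = x
      }
    
      // Inductive case, like AdderType (S k) = Int -> AdderType k: take one more Int,
      // add it to the accumulator, and hand the sum to the instance for N.
      implicit def adderTypeSucc[N <: Nat](implicit
        atN: AdderType[N]
      ): Aux[Succ[N], Int => atN.Out] = new AdderType[Succ[N]] {
        type Out = Int => atN.Out
        def apply(x: Int): Int => atN.Out = i => atN(x + i)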
      }
    }
    

    And then:

    scala> val at3 = AdderType[Nat._3](0)
    at3: Int => (Int => (Int => Int)) = <function1>
    
    scala> at3(3)(3)(3)
    res8: Int = 9
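    As for the "without shapeless" part of the question, the same implicit derivation can be sketched with a minimal hand-rolled type-level natural number in place of shapeless.Nat and DepFn1. This is a sketch under assumptions, not part of the original answer: the names NoShapelessAdder, Z, S, zero and succ are hypothetical, and it is written with Scala 2.13 implicits in mind.

```scala
object NoShapelessAdder {
  // Minimal type-level naturals (hypothetical stand-ins for shapeless.Nat).
  trait Nat
  trait Z extends Nat
  trait S[N <: Nat] extends Nat

  // Same shape as the shapeless version: a dependent function Int => Out,
  // where Out is computed from N by implicit resolution.
  trait AdderType[N <: Nat] {
    type Out
    def apply(acc: Int): Out
  }

  object AdderType {
    type Aux[N <: Nat, Out0] = AdderType[N] { type Out = Out0 }

    // Summon the instance for N and start it off with the initial accumulator.
    def apply[N <: Nat](base: Int)(implicit at: AdderType[N]): at.Out = at(base)

    // Base case: no arguments left, return the accumulator.
    implicit val zero: Aux[Z, Int] = new AdderType[Z] {
      type Out = Int
      def apply(acc: Int): Int = acc
    }

    // Inductive case: accept one more Int, add it to the accumulator,
    // and continue with the instance for N.
    implicit def succ[N <: Nat](implicit atN: AdderType[N]): Aux[S[N], Int => atN.Out] =
      new AdderType[S[N]] {
        type Out = Int => atN.Out
        def apply(acc: Int): Int => atN.Out = x => atN(acc + x)
      }
  }
}
```

    Usage mirrors the shapeless version: after import NoShapelessAdder._, AdderType[S[S[S[Z]]]](0) should have type Int => (Int => (Int => Int)), and applying it to 3, 3 and 3 gives 9. In this formulation, what shapeless supplies is the Nat encoding (with convenient aliases such as Nat._3) and the DepFn1 trait; the recursion itself is ordinary implicit resolution.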
    

    Original answer below.


    Here's an off-the-cuff Scala translation:

    import shapeless._
    
    trait AdderType[N <: Nat] extends DepFn1[Int] {
      protected def plus(x: Int): AdderType.Aux[N, Out]
    }
    
    object AdderType {
      type Aux[N <: Nat, Out0] = AdderType[N] { type Out = Out0 }
    
      def apply[N <: Nat](base: Int)(implicit at: AdderType[N]): Aux[N, at.Out] =
        at.plus(base)
    
      private[this] case class AdderTypeZero(acc: Int) extends AdderType[Nat._1] {
        type Out = Int
        def apply(x: Int): Int = acc + x
        protected def plus(x: Int): Aux[Nat._1, Int] = copy(acc = acc + x)
      }
    
      private[this] case class AdderTypeSucc[N <: Nat, Out0](
        atN: Aux[N, Out0],
        acc: Int
      ) extends AdderType[Succ[N]] {
        type Out = Aux[N, Out0]
        def apply(x: Int): Aux[N, Out0] = atN.plus(acc + x)
        protected def plus(x: Int): Aux[Succ[N], Aux[N, Out0]] = copy(acc = acc + x)
      }
    
      implicit val adderTypeZero: Aux[Nat._1, Int] = AdderTypeZero(0)
      implicit def adderTypeSucc[N <: Nat](implicit
        atN: AdderType[N]
      ): Aux[Succ[N], Aux[N, atN.Out]] = AdderTypeSucc(atN, 0)
    }
    

    And then:

    scala> val at3 = AdderType[Nat._3](0)
    at3: AdderType[shapeless.Succ[shapeless.Succ[shapeless.Succ[shapeless._0]]]] { ...
    
    scala> at3(3)(3)(3)
    res0: Int = 9
    

    It's more verbose, and the representation is a little different in order to make the Scala syntax work out: our "base case" is essentially an Int => Int (indexed by Nat._1) rather than an Int, because otherwise I don't see a way to avoid needing to write apply or () everywhere. The basic ideas are exactly the same, though.
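    To make the base-case remark concrete, here is a plain-Scala, value-level sketch (no shapeless, no type-level Nat) of the kind of object chain the original answer builds for three arguments. The class names are hypothetical. Each step is an object carrying an accumulator; because the innermost object is itself an Int => Int, the third application yields an Int directly. The Done variant shows the alternative: if the chain bottomed out in an object wrapping the result, the caller would need one extra step to get the Int out.

```scala
object ObjectChain {
  // Innermost step, analogous to AdderTypeZero at Nat._1:
  // it is already an Int => Int, so applying it yields the final Int.
  final case class Base(acc: Int) extends (Int => Int) {
    def apply(x: Int): Int = acc + x
  }

  // Intermediate steps, analogous to AdderTypeSucc: add the argument
  // to the accumulator and hand the running total to the next object.
  final case class Two(acc: Int) extends (Int => Base) {
    def apply(x: Int): Base = Base(acc + x)
  }

  final case class Three(acc: Int) extends (Int => Two) {
    def apply(x: Int): Two = Two(acc + x)
  }

  // For contrast: a base case that is a plain value wrapper, not a function.
  // A chain ending here needs an extra `.total` to extract the Int.
  final case class Done(total: Int)

  final case class OneToDone(acc: Int) extends (Int => Done) {
    def apply(x: Int): Done = Done(acc + x)
  }
}
```

    Here Three(0)(3)(3)(3) is 9 with no trailing call, whereas OneToDone(0)(3) is Done(3) and needs .total to get at the Int.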