scala variadic-functions currying scala-3

Express function of arbitrary arity in vanilla Scala 3


I'm trying to get a grasp of the Scala 3 type system. Question:

  • Is it possible to write a single universal def curry(f: ???) = ... function that accepts an f of any arity and returns its curried form? No compiler plugins and no external libraries, just a function of arbitrary arity expressed in plain Scala 3.
  • I'm looking at this Haskell example, https://riptutorial.com/haskell/example/18470/an-n-arity-curry, which does something similar to what I need.

(The purpose of this question is not to use an external library but to learn functional programming concepts with Scala 3 as the tool. I have a feeling this might involve treating the arguments as a tuple, or converting the function into a tupled function. There seems to be some symmetry between a function's arguments and the concept of a tuple.)


Solution

  • In contrast to Haskell, Scala has two different function types: (X1, ..., Xn) => Y (aka FunctionN[X1, ..., Xn, Y]) and ((X1, ..., Xn)) => Y (aka Function1[TupleN[X1, ..., Xn], Y]). To transform the latter into X1 => ... => Xn => Y (aka Function1[X1, Function1[..., Function1[Xn, Y]...]]), you can use match types, inline methods, and compile-time operations:

    import scala.compiletime.{erasedValue, summonFrom}
    
    type Reverse[T <: Tuple] = ReverseLoop[T, EmptyTuple]
    
    inline def reverse[T <: Tuple](t: T): Reverse[T] = reverseLoop(t, EmptyTuple)
    
    type ReverseLoop[T <: Tuple, S <: Tuple] <: Tuple = T match
      case EmptyTuple => S
      case h *: t => ReverseLoop[t, h *: S]
    
    inline def reverseLoop[T <: Tuple, S <: Tuple](x: T, acc: S): ReverseLoop[T, S] =
      inline x match
        case _: EmptyTuple => acc
        case x: (_ *: _) => x match
          case h *: t => reverseLoop(t, h *: acc)
    
    type Curry[T <: Tuple, Y] = CurryLoop[T, T, EmptyTuple, Y]
    
    inline def curry[T <: Tuple, Y](f: T => Y): Curry[T, Y] =
      curryLoop[T, T, EmptyTuple, Y](f, EmptyTuple)
    
    type CurryLoop[T1 <: Tuple, T <: Tuple, S <: Tuple, Y] = T1 match
      case EmptyTuple => Y
      case h *: t => h => CurryLoop[t, T, h *: S, Y]
    
    inline def curryLoop[T1 <: Tuple, T <: Tuple, S <: Tuple, Y](
      f: T => Y,
      acc: S
    ): CurryLoop[T1, T, S, Y] = inline erasedValue[T1] match
      case _: EmptyTuple => summonFrom {
        case _: (Reverse[S] =:= T) => f(reverse(acc))
      }
      case _: (h *: t) => (h: h) => curryLoop[t, T, h *: S, Y](f, h *: acc)
    

    Testing:

    // compiles
    summon[Curry[(Int, String, Boolean), String] =:= (Int => String => Boolean => String)]
    
    val f: ((Int, String, Boolean)) => String = t => s"${t._1}, ${t._2}, ${t._3}"
    val g = curry(f)
    g: (Int => String => Boolean => String) // checking the type
    g(1)("a")(true) // 1, a, true
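
The three compile-time tools used above (a match type, an inline method driven by erasedValue, and summonFrom) can be seen in isolation in a smaller sketch. The names Arity, arity, hasOrdering and NoOrder are made up for this illustration and are not part of the answer's code.

```scala
import scala.compiletime.{erasedValue, summonFrom}
import scala.compiletime.ops.int.S

// Type level: a match type that counts the elements of a tuple type.
type Arity[T <: Tuple] <: Int = T match
  case EmptyTuple => 0
  case _ *: t     => S[Arity[t]]

// Term level: an inline method that follows the same recursion.
// erasedValue[T] only exists at compile time and steers the inline match.
inline def arity[T <: Tuple]: Int = inline erasedValue[T] match
  case _: EmptyTuple => 0
  case _: (_ *: t)   => 1 + arity[t]

// summonFrom picks a branch depending on which given is available.
inline def hasOrdering[A]: Boolean = summonFrom {
  case _: Ordering[A] => true
  case _              => false
}

// A class without any Ordering instance (illustrative only).
final class NoOrder

// Compiles only if the match type reduces to the literal type 3.
val arityEvidence = summon[Arity[(Int, String, Boolean)] =:= 3]

val three          = arity[(Int, String, Boolean)]
val intOrdered     = hasOrdering[Int]
val noOrderOrdered = hasOrdering[NoOrder]
```

In curryLoop, summonFrom plays the same role as in hasOrdering: it asks the compiler for the evidence Reverse[S] =:= T once all arguments have been collected.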
    

    See also: Scala 3: typed tuple zipping


    Alternatively, you can still use good old type classes:

    trait Reverse[T <: Tuple]:
      type Out <: Tuple
      def apply(t: T): Out
    
    object Reverse:
      type Aux[T <: Tuple, Out0 <: Tuple] = Reverse[T] {type Out = Out0}
      def instance[T <: Tuple, Out0 <: Tuple](f: T => Out0): Aux[T, Out0] =
        new Reverse[T]:
          override type Out = Out0
          override def apply(t: T): Out = f(t)
    
      given [T <: Tuple](using
        reverseLoop: ReverseLoop[T, EmptyTuple]
      ): Aux[T, reverseLoop.Out] = instance(t => reverseLoop(t, EmptyTuple))
    
    trait ReverseLoop[T <: Tuple, S <: Tuple]:
      type Out <: Tuple
      def apply(t: T, acc: S): Out
    
    object ReverseLoop:
      type Aux[T <: Tuple, S <: Tuple, Out0 <: Tuple] =
        ReverseLoop[T, S] {type Out = Out0}
      def instance[T <: Tuple, S <: Tuple, Out0 <: Tuple](
        f: (T, S) => Out0
      ): Aux[T, S, Out0] = new ReverseLoop[T, S]:
        override type Out = Out0
        override def apply(t: T, acc: S): Out = f(t, acc)
    
      given [S <: Tuple]: Aux[EmptyTuple, S, S] = instance((_, acc) => acc)
    
      given [H, T <: Tuple, S <: Tuple](using
        reverseLoop: ReverseLoop[T, H *: S]
      ): Aux[H *: T, S, reverseLoop.Out] =
        instance((l, acc) => reverseLoop(l.tail, l.head *: acc))
    
    trait Curry[T <: Tuple, Y]:
      type Out
      def apply(f: T => Y): Out
    
    object Curry:
      type Aux[T <: Tuple, Y, Out0] = Curry[T, Y] {type Out = Out0}
      def instance[T <: Tuple, Y, Out0](g: (T => Y) => Out0): Aux[T, Y, Out0] =
        new Curry[T, Y]:
          override type Out = Out0
          override def apply(f: T => Y): Out = g(f)
    
      given [T <: Tuple, Y](using
        curryLoop: CurryLoop[T, T, EmptyTuple, Y]
      ): Aux[T, Y, curryLoop.Out] = instance(f => curryLoop(f, EmptyTuple))
    
    trait CurryLoop[T1 <: Tuple, T <: Tuple, S <: Tuple, Y]:
      type Out
      def apply(f: T => Y, acc: S): Out
    
    object CurryLoop:
      type Aux[T1 <: Tuple, T <: Tuple, S <: Tuple, Y, Out0] =
        CurryLoop[T1, T, S, Y] {type Out = Out0}
      def instance[T1 <: Tuple, T <: Tuple, S <: Tuple, Y, Out0](
        g: (T => Y, S) => Out0
      ): Aux[T1, T, S, Y, Out0] = new CurryLoop[T1, T, S, Y]:
        override type Out = Out0
        override def apply(f: T => Y, acc: S): Out = g(f, acc)
    
      given [S <: Tuple, Y](using
        reverse: Reverse[S]
      ): Aux[EmptyTuple, reverse.Out, S, Y, Y] =
        instance((f, acc) => f(reverse(acc)))
    
      given [H1, T1 <: Tuple, T <: Tuple, S <: Tuple, Y](using
        curryLoop: CurryLoop[T1, T, H1 *: S, Y]
      ): Aux[H1 *: T1, T, S, Y, H1 => curryLoop.Out] =
        instance((f, acc) => h1 => curryLoop(f, h1 *: acc))
    
    def curry[T <: Tuple, Y](f: T => Y)(using
      curryInst: Curry[T, Y]
    ): curryInst.Out = curryInst(f)
    

    Testing:

    // compiles
    summon[Curry.Aux[(Int, String, Boolean), String, Int => String => Boolean => String]]
    
    val c = summon[Curry[(Int, String, Boolean), String]]  // compiles
    summon[c.Out =:= (Int => String => Boolean => String)] // compiles
    
    val f: ((Int, String, Boolean)) => String = t => s"${t._1}, ${t._2}, ${t._3}"
    val g = curry(f)
    g: (Int => String => Boolean => String) // checking the type
    g(1)("a")(true) // 1, a, true
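
For comparison, here is a shorter type-class sketch that curries by direct recursion on the tuple, without the accumulator and the Reverse step. It is an alternative under the same assumptions, not the implementation above, and the names DirectCurry and directCurry are hypothetical.

```scala
trait DirectCurry[T <: Tuple, Y]:
  type Out
  def apply(f: T => Y): Out

object DirectCurry:
  type Aux[T <: Tuple, Y, Out0] = DirectCurry[T, Y] { type Out = Out0 }

  def instance[T <: Tuple, Y, Out0](g: (T => Y) => Out0): Aux[T, Y, Out0] =
    new DirectCurry[T, Y]:
      type Out = Out0
      def apply(f: T => Y): Out = g(f)

  // Base case: no arguments are left, so f is applied to the empty tuple.
  given base[Y]: Aux[EmptyTuple, Y, Y] =
    instance[EmptyTuple, Y, Y](f => f(EmptyTuple))

  // Inductive case: take the head argument h, then curry the function
  // that expects the remaining tuple t and calls f with h prepended.
  given step[H, T <: Tuple, Y](using
    rest: DirectCurry[T, Y]
  ): Aux[H *: T, Y, H => rest.Out] =
    instance[H *: T, Y, H => rest.Out](f =>
      (h: H) => rest((t: T) => f(h *: t))
    )

def directCurry[T <: Tuple, Y](f: T => Y)(using c: DirectCurry[T, Y]): c.Out =
  c(f)

val join: ((Int, String, Boolean)) => String =
  t => s"${t._1}, ${t._2}, ${t._3}"

val curriedJoin = directCurry(join)

// Type ascription that checks the inferred result type.
val curriedJoinChecked: Int => String => Boolean => String = curriedJoin

val joined = curriedJoin(1)("a")(true)
```

The trade-off in this sketch is that every step wraps f in one more closure, whereas the accumulator version builds the argument tuple once and calls f directly.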
    

    A tupled method that transforms (X1, ..., Xn) => Y into ((X1, ..., Xn)) => Y can be implemented as a transparent macro. Transparent (the counterpart of whitebox in Scala 2) means that the macro can return a type more precise than the declared one.

    import scala.quoted.*
    
    transparent inline def tupled[F](f: F): Any = ${tupledImpl('f)}
    
    def tupledImpl[F: Type](f: Expr[F])(using Quotes): Expr[Any] =
      import quotes.reflect.*
    
      val allTypeArgs = TypeRepr.of[F].typeArgs
      val argTypes    = allTypeArgs.init
      val argCount    = argTypes.length
      val returnType  = allTypeArgs.last
    
      val tupleType = AppliedType(
        TypeRepr.typeConstructorOf(Class.forName(s"scala.Tuple$argCount")),
        argTypes
      )
    
      (tupleType.asType, returnType.asType) match
        case ('[t], '[b]) => '{
          (a: t) => ${
            Apply(
              Select.unique(f.asTerm, "apply"),
              (1 to argCount).toList.map(i => Select.unique('a.asTerm, s"_$i"))
            ).asExprOf[b]
          }
        }
    

    Testing:

    val f: (Int, String, Boolean) => String = (i, s, b) => s"$i, $s, $b"
    val g = tupled(f)
    g: (((Int, String, Boolean)) => String) // checking the type
    g((1, "a", true)) // 1, a, true
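
The remark that a transparent definition can return a type more precise than declared does not depend on macros. A minimal sketch with a plain transparent inline method shows the effect; pick and pickWidened are names invented for this example.

```scala
// Declared result type is Any, but because the method is transparent,
// each call site gets the type of the branch that survives inlining.
transparent inline def pick(inline useInt: Boolean): Any =
  inline if useInt then 42 else "forty-two"

val n: Int    = pick(true)   // accepted: the call is typed as an Int
val s: String = pick(false)  // accepted: the call is typed as a String

// The same body without `transparent` keeps the declared type Any.
inline def pickWidened(inline useInt: Boolean): Any =
  inline if useInt then 42 else "forty-two"

// `val m: Int = pickWidened(true)` would be rejected; only Any is visible.
val widened: Any = pickWidened(true)
```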
    

    This gives us curry for types (X1, ..., Xn) => Y:

    curry(tupled(f))(1)("a")(true) // 1, a, true
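
For one fixed arity, the standard library already offers the same conversions as methods on FunctionN, which makes the difference between the two function types easy to see and gives a cross-check for the generic version. The value names in this sketch are illustrative.

```scala
// A Function3: takes three separate arguments.
val f3: (Int, String, Boolean) => String = (i, s, b) => s"$i, $s, $b"

// Function3#tupled: a Function1 whose single argument is a Tuple3.
val f3Tupled: ((Int, String, Boolean)) => String = f3.tupled

// Function3#curried: a chain of three Function1 values.
val f3Curried: Int => String => Boolean => String = f3.curried

// Function.untupled converts back from the tupled form.
val f3Again: (Int, String, Boolean) => String = Function.untupled(f3Tupled)

val viaTuple   = f3Tupled((1, "a", true))
val viaCurried = f3Curried(1)("a")(true)
val viaAgain   = f3Again(1, "a", true)
```

These methods are defined separately for each arity, which is why the generic curry needs match types, type classes, or macros.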
    

    Although curry(tupled(f)) works for a specific f, it's not easy to specify the signature of a method that composes curry and tupled:

    // for match-type implementation of curry
    
    transparent inline def curry1[F](f: F): Any = curry(tupled(f))
    
    curry1(f)(1)("a")(true)
    // doesn't compile: method curry1 ... does not take more parameters
    
    // for type-class implementation of curry
    
    transparent inline def curry1[F](f: F): Any = curry(tupled(f))
    // doesn't compile: No given instance of type Curry[Nothing, Any] was found...
    // (and what types to specify in (using Curry[???, ???]) ?)
    

    I thought that "Recovering precise types using patterns" should help if I made curry1 a macro too:

    transparent inline def curry1[F](f: F): Any = ${curry1Impl[F]('f)}
    
    def curry1Impl[F: Type](f: Expr[F])(using Quotes): Expr[Any] =
      import quotes.reflect.*
    
      '{ tupled[F]($f) } match
        case
          '{
            type t <: Tuple
            $x: (`t` => y)
          } =>
            Expr.summon[Curry[t, y]] match
              case Some(c) => '{curry[t, y]($x)(using $c)}
    

    But it doesn't. If tupled is declared as transparent inline def tupled[F](f: F): Any = ..., then '{ tupled[F]($f) } doesn't match '{...; $x: (`t` => y)}. If it is declared as transparent inline def tupled[F](f: F): Function1[?, ?] = ..., then t is inferred as Nothing and y as Any.

    So let's make tupled an implicit macro (a type class) in order to better control its return type:

    import scala.quoted.*
    
    trait Tupled[F]:
      type Out
      def apply(f: F): Out
    
    object Tupled:
      type Aux[F, Out0] = Tupled[F] { type Out = Out0 }
      def instance[F, Out0](g: F => Out0): Aux[F, Out0] = new Tupled[F]:
        type Out = Out0
        def apply(f: F): Out = g(f)
    
      transparent inline given [F]: Tupled[F] = ${mkTupledImpl[F]}
    
      def mkTupledImpl[F: Type](using Quotes): Expr[Tupled[F]] =
        import quotes.reflect.*
        val allTypeArgs = TypeRepr.of[F].typeArgs
        val argTypes    = allTypeArgs.init
        val argCount    = argTypes.length
        val returnType  = allTypeArgs.last
    
        val tupleType = AppliedType(
          TypeRepr.typeConstructorOf(Class.forName(s"scala.Tuple$argCount")),
          argTypes
        )
    
        (tupleType.asType, returnType.asType) match
          case ('[t], '[b]) => '{
            instance[F, t => b]((f: F) => (a: t) => ${
              Apply(
                Select.unique('f.asTerm, "apply"),
                (1 to argCount).toList.map(i => Select.unique('a.asTerm, s"_$i"))
              ).asExprOf[b]
            })
          }
    
    def tupled[F](f: F)(using tupledInst: Tupled[F]): tupledInst.Out = tupledInst(f)
    
    // for match-type implementation of curry
    
    inline def curry1[F, T <: Tuple, Y](f: F)(using
      tupledInst: Tupled[F],
      ev: tupledInst.Out <:< (T => Y),
    ): Curry[T, Y] = curry(tupled(f))
    

    Testing:

    val f: (Int, String, Boolean) => String = (i, s, b) => s"$i, $s, $b"
    val g = curry1(f)
    g : (Int => String => Boolean => String) // checking the type
    g(1)("a")(true) // 1, a, true
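
To see what the macro-generated given amounts to, the same type-class shape can be written by hand for a couple of fixed arities. This is only a sketch of the idea with hypothetical names (ManualTupled, manualTupled); unlike the macro, it covers just the arities listed.

```scala
trait ManualTupled[F]:
  type Out
  def apply(f: F): Out

object ManualTupled:
  type Aux[F, Out0] = ManualTupled[F] { type Out = Out0 }

  def instance[F, Out0](g: F => Out0): Aux[F, Out0] = new ManualTupled[F]:
    type Out = Out0
    def apply(f: F): Out = g(f)

  // One given per arity, each delegating to the standard FunctionN#tupled.
  given tupled2[A, B, R]: Aux[(A, B) => R, ((A, B)) => R] =
    instance[(A, B) => R, ((A, B)) => R](f => f.tupled)

  given tupled3[A, B, C, R]: Aux[(A, B, C) => R, ((A, B, C)) => R] =
    instance[(A, B, C) => R, ((A, B, C)) => R](f => f.tupled)

// The result type depends on the instance that was found for F.
def manualTupled[F](f: F)(using t: ManualTupled[F]): t.Out = t(f)

val repeat = manualTupled((n: Int, s: String) => s * n)

// Type ascription that checks the inferred result type.
val repeatChecked: ((Int, String)) => String = repeat

val repeated = repeat((2, "ab"))
```

Because Out is fixed by the given's declared type rather than by a transparent macro expansion, callers can see the precise function type without any pattern matching on quoted code.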
    

    As an alternative to tupled, you can try the built-in (experimental) type class scala.util.TupledFunction https://docs.scala-lang.org/scala3/reference/experimental/tupled-function.html (thanks to @MartinHH for pointing this out in the comments):

    import scala.util.TupledFunction
    
    // for match-type implementation of curry
    
    inline def curry1[F, T <: Tuple, Y](f: F)(using
      tf: TupledFunction[F, T => Y]
    ): Curry[T, Y] = curry(tf.tupled(f))
    
    // for type-class implementation of curry
    
    def curry1[F, T <: Tuple, Y](f: F)(using
      tf: TupledFunction[F, T => Y],
      c: Curry[T, Y]
    ): c.Out = curry(tf.tupled(f))
    

    TupledFunction is similar to the type classes shapeless.ops.function.{FnToProduct, FnFromProduct} from shapeless in Scala 2:

    https://github.com/milessabin/shapeless/wiki/Feature-overview:-shapeless-2.0.0#facilities-for-abstracting-over-arity

    See also:

      • Partial function application in Scala for arbitrary input arguments
      • Function taking another function of arbitrary arity as argument
      • Scala's type system and the input to FunctionN