Search code examples
Tags: scala, shapeless

How to flatten a nested-tuple parameter of a function?


I have a function g that takes a parameter of type ((Int, (Int, Int))) => Int, and a flat function f0 of type ((Int, Int, Int)) => Int.

I want to write a function ft that adapts f0, whose parameter is flat, to the nested parameter shape that g expects. Here is an example:

/** Sums the three components of a flat `(Int, Int, Int)` tuple. */
val f0: ((Int, Int, Int)) => Int = { case (a, b, c) => a + b + c }

/** Applies `f` to the fixed nested tuple `(1, (2, 3))`.
  *
  * The argument is passed as an explicit tuple. Writing `f(1, (2, 3))` only
  * compiles through auto-tupling of the argument list, which
  * `-Xlint:adapted-args` flags.
  */
def g(f: ((Int, (Int, Int))) => Int): Int = f((1, (2, 3)))

/** Adapts a function over a flat triple to one over the nested shape `(a, (b, c))`.
  *
  * @param f function taking the flattened triple `(a, b, c)`
  * @return a function that destructures `(a, (b, c))` and calls `f((a, b, c))`
  */
def ft(f: ((Int, Int, Int)) => Int): ((Int, (Int, Int))) => Int = {
  // Build the triple explicitly rather than relying on auto-tupling of f's arguments.
  case (a, (b, c)) => f((a, b, c))
}

// Invoke it: ft adapts the flat f0 to the nested parameter shape g expects,
// so this evaluates f0 on (1, 2, 3), i.e. 6.
g(ft(f0))

But I have several functions over nested tuples, and I don't want to write each adapter by hand — for example, from ((Int, Int), (Int, Int)) => Int to (Int, Int, Int, Int) => Int.

An answer here suggested that shapeless could do this. The new function would then look like this:


import shapeless._
import ops.tuple.FlatMapper

/** Fallback case of the `flatten` Poly: wraps any value `t` that no
  * higher-priority case handles as `Tuple1(t)`.
  *
  * It lives in a parent trait so that implicits defined in the extending object
  * (the tuple case) take precedence during implicit resolution.
  */
trait LowPriorityFlatten extends Poly1 {
  // Explicit result type: implicit definitions should not depend on type inference.
  implicit def default[T]: Case.Aux[T, Tuple1[T]] = at[T](Tuple1(_))
}

/** Recursive Poly1 used to flatten nested tuples. A tuple argument is
  * flat-mapped with this same Poly (through FlatMapper). Any other value falls
  * back to the inherited `default` case.
  */
object flatten extends LowPriorityFlatten {
  // Lazy is shapeless's wrapper for recursive implicit derivation; it is used
  // here because the required FlatMapper refers back to flatten.type itself.
  // NOTE(review): the result type is inferred from `lfm.value(_)`. Because
  // `lfm.value` is not a stable path, the concrete output type is presumably
  // widened here, which is consistent with the `FlatMapper.this.Out` seen at call
  // sites. Declaring `Case.Aux[P, Out]` with a `FlatMapper.Aux[P, flatten.type, Out]`
  // parameter may preserve it — TODO confirm with the compiler.
  implicit def caseTuple[P <: Product](implicit lfm: Lazy[FlatMapper[P, flatten.type]]) =
    at[P](lfm.value(_))
}

/** Shapeless-based attempt at ft: flattens the nested argument with `flatten`
  * and passes the flat result to `f`.
  */
def ft(f: ((Int, Int, Int)) => Int): ((Int, (Int, Int))) => Int = (p: (Int, (Int, Int)))  => {
  // Unchecked cast: flatten(p) is not statically typed as (Int, Int, Int). At
  // runtime only the Tuple3 class is checked; the element types are erased.
  val a: (Int, Int, Int) = flatten(p).asInstanceOf[(Int, Int, Int)]
  f(a)
}

The code above has two problems:

  1. How can I define a generic ft[A, B, C](f: A => C): B => C, where A is the flattened type of B?
  2. flatten(p) produces a value of the abstract type FlatMapper.this.Out, so the concrete tuple type is lost and I have to use asInstanceOf to cast it.

So, how can I write a function that flattens an arbitrarily nested tuple in a function parameter?


Solution

  • The following code works in Scala 3:

    scala> type Flat[T <: Tuple] <: Tuple = T match
         |   case EmptyTuple => EmptyTuple
         |   case h *: t => h match
         |     case Tuple => Tuple.Concat[Flat[h], Flat[t]]
         |     case _     => h *: Flat[t]
         |
    
    scala> def flat[T <: Tuple](v: T): Flat[T] = (v match
         |   case e: EmptyTuple => e
         |   case h *: ts => h match
         |     case t: Tuple => flat(t) ++ flat(ts)
         |     case _        => h *: flat(ts)).asInstanceOf[Flat[T]]
    def flat[T <: Tuple](v: T): Flat[T]
    
    scala> def ft[A <: Tuple, C](f: Flat[A] => C): A => C = a => f(flat(a))
    def ft[A <: Tuple, C](f: Flat[A] => C): A => C
    
    scala> val f0: ((Int, Int, Int)) => Int = x => x._1 + x._2 + x._3
    
    
    scala> def g0(f: ((Int, (Int, Int))) => Int): Int = f(1,(2,3))
    
    scala> g0(ft(f0))
    val res0: Int = 6
    

    Edit: added a Scala 2 version:

      import shapeless._
      import ops.tuple.FlatMapper
      import syntax.std.tuple._
    
      /** Fallback case of the `Flat` Poly: wraps any value `t` that no
        * higher-priority case handles as `Tuple1(t)`.
        *
        * It lives in a parent trait so that the tuple case defined in the
        * extending object takes precedence during implicit resolution.
        */
      trait LowPriorityFlat extends Poly1 {
        // Explicit result type: implicit definitions should not depend on type inference.
        implicit def default[T]: Case.Aux[T, Tuple1[T]] = at[T](Tuple1(_))
      }
      /** Recursive Poly1 that flattens nested tuples. Tuple arguments are
        * flat-mapped with `Flat` itself. Anything else uses the inherited
        * `default` case.
        */
      object Flat extends LowPriorityFlat {
        // `fm` is a stable implicit parameter, so the result type can name fm.Out
        // explicitly. This keeps the flattened tuple type visible to callers.
        implicit def caseTuple[P <: Product](
            implicit fm: FlatMapper[P, Flat.type]
        ): Case.Aux[P, fm.Out] =
          at[P](_.flatMap(Flat))
      }
    
      /** Evidence that flattening `A` with `Flat` yields exactly the tuple type `B`. */
      type F[A, B] = FlatMapper[A, Flat.type] { type Out = B }
      
      /** Flattens a (possibly nested) tuple into a flat tuple.
        *
        * @param t   the tuple to flatten
        * @param lfm FlatMapper instance that computes the flattened type
        * @return the flattened tuple, typed precisely as `lfm.Out`
        */
      def flatTup[T <: Product](t: T)(implicit lfm: FlatMapper[T, Flat.type]): lfm.Out =
        lfm(t) // use the instance already in hand instead of re-summoning it
    
      /** Lifts a function over the flat tuple `B` to one over the nested tuple `A`.
        *
        * @param f   function expecting the flattened tuple
        * @param lfm evidence that flattening `A` with `Flat` yields exactly `B`
        * @return a function that flattens its nested argument and then applies `f`
        */
      def flatFun[A <: Product, B <: Product, C](f: B => C)(implicit lfm: F[A, B]): A => C = {
        nested =>
          // lfm pins the flattened type to B, so no cast is needed here.
          val flat: B = flatTup(nested)
          f(flat)
      }
      
      /** Computes (i1 + i2) / (d1 + d2) for a flat (i1, d1, i2, d2) tuple. */
      val f0: ((Int, Double, Int, Double)) => Double = t => (t._1 + t._3) / (t._2 + t._4)
      /** Applies `f` to the fixed nested tuple `((1, 2.0), (3, 4.0))`.
        *
        * The argument is an explicit tuple. Writing `f((1, 2.0), (3, 4.0))` only
        * compiles through auto-tupling, which `-Xlint:adapted-args` flags.
        */
      def g0(f: (((Int, Double), (Int, Double))) => Double): Double = f(((1, 2.0), (3, 4.0)))
      // g0 expects a function over nested tuples; flatFun derives the adapter to the
      // flat f0, so r0 = (1 + 3) / (2.0 + 4.0).
      val r0 = g0(flatFun(f0))