
How to narrow the type of an upper-bounded type parameter in a state machine encoding?


Say I have a Cake that can cycle through a number of states:

sealed trait State extends Product with Serializable
object State {
    final case object Raw extends State
    final case class JustRight(temperature: Int) extends State
    final case class Burnt(charCoalContent: Double) extends State
}
final case class Cake[S <: State](name: String, state: S)

This is nice, because now I can make sure that I only try to put Raw cakes into the oven, instead of eating them right away.
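To make that guarantee concrete, here is a minimal sketch using two hypothetical functions, `bake` and `eat`, which are not part of the question. The state is part of the cake's type, so the compiler rejects a cake in the wrong state at the call site.

```scala
object OvenDemo {
  sealed trait State extends Product with Serializable
  object State {
    case object Raw extends State
    final case class JustRight(temperature: Int) extends State
    final case class Burnt(charCoalContent: Double) extends State
  }
  final case class Cake[S <: State](name: String, state: S)

  // Hypothetical oven: its signature only admits raw cakes.
  def bake(cake: Cake[State.Raw.type], temperature: Int): Cake[State.JustRight] =
    cake.copy(state = State.JustRight(temperature))

  // Hypothetical consumer: only accepts cakes that are just right.
  def eat(cake: Cake[State.JustRight]): String =
    s"Ate ${cake.name}, baked at ${cake.state.temperature} degrees"

  val raw: Cake[State.Raw.type] = Cake("Foo", State.Raw)
  val baked: Cake[State.JustRight] = bake(raw, 180)
  // eat(raw)          // does not compile: Cake[Raw.type] is not a Cake[JustRight]
  // bake(baked, 180)  // does not compile either: the cake is no longer raw
}
```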

But sometimes I just have a Cake[State] lying around and want to try to eat it, but only if it's in an edible state. I could of course always pattern match on cake.state, but I thought it should be possible to save myself a few keystrokes by adding a method def narrow[S <: State]: Cake[State] => Option[Cake[S]].
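For comparison, the explicit pattern match that `narrow` is meant to replace might look roughly like this. `eat` and `tryToEat` are hypothetical helpers; the point is that matching on the state value (rather than on the cake's type parameter) is a genuine runtime check.

```scala
object MatchDemo {
  sealed trait State extends Product with Serializable
  object State {
    case object Raw extends State
    final case class JustRight(temperature: Int) extends State
    final case class Burnt(charCoalContent: Double) extends State
  }
  final case class Cake[S <: State](name: String, state: S)

  // Hypothetical consumer that only accepts edible cakes.
  def eat(cake: Cake[State.JustRight]): String = s"Ate ${cake.name}"

  // Type test on the concrete state class, then a copy at the narrowed type.
  def tryToEat(cake: Cake[State]): Option[String] =
    cake.state match {
      case s: State.JustRight => Some(eat(cake.copy(state = s)))
      case _                  => None
    }
}
```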

However, now I'm struggling to actually implement that function. The compiler accepts Try(cake.asInstanceOf[Cake[S]]).toOption, but it seems that would always succeed (I guess because the type parameter is erased, and actually any type A would be accepted here, not just S). What seems to work is Try(cake.copy(state = cake.state.asInstanceOf[S])).toOption, but now I've made a superfluous copy of my data. Is there another better way? Or is that whole encoding maybe flawed from the get-go?
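The erasure explanation can be checked with a small sketch (`narrowByCast` and `narrowByStateCast` are hypothetical names). Inside a generic method, `asInstanceOf[Cake[S]]` is erased to a cast to `Cake`, and `asInstanceOf[S]` is erased to a cast to the upper bound `State`, so neither cast can fail there. The `ClassCastException` only appears later, when the state is used at the narrowed type.

```scala
import scala.util.Try

object ErasureDemo {
  sealed trait State extends Product with Serializable
  object State {
    case object Raw extends State
    final case class JustRight(temperature: Int) extends State
    final case class Burnt(charCoalContent: Double) extends State
  }
  final case class Cake[S <: State](name: String, state: S)

  // Cast of the whole cake: S is erased, so this never throws inside the method.
  def narrowByCast[S <: State](cake: Cake[State]): Option[Cake[S]] =
    Try(cake.asInstanceOf[Cake[S]]).toOption

  // Cast of the state: erased to a cast to State, so this never throws here either.
  def narrowByStateCast[S <: State](cake: Cake[State]): Option[Cake[S]] =
    Try(cake.copy(state = cake.state.asInstanceOf[S])).toOption
}
```

With a raw cake, `narrowByCast[JustRight]` and `narrowByStateCast[JustRight]` both return a `Some`; reading `.state.temperature` from the result is what fails.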


Solution

  • You can solve this problem using a typeclass whose instances check the runtime state and, if it matches, return it at the narrower type, so no unchecked cast is needed.

    sealed trait State extends Product with Serializable
    object State {
        final case object Raw extends State
        type Raw = Raw.type
        final case class JustRight(temperature: Int) extends State
        final case class Burnt(charCoalContent: Double) extends State
      
        sealed trait Checker[S <: State] {
          def check(state: State): Option[S]
        }
        object Checker {
          private def instance[S <: State](pf: PartialFunction[State, S]): Checker[S] =
            new Checker[S] {
              val f = pf.lift
              override def check(state: State): Option[S] = f(state)
            }
          
          implicit final val RawChecker: Checker[Raw] = instance {
            case Raw => Raw
          }
          
          implicit final val JustRightChecker: Checker[JustRight] = instance {
            case s @ JustRight(_) => s
          }
          
          implicit final val BurntChecker: Checker[Burnt] = instance {
            case s @ Burnt(_) => s
          }
        }
    }
    
    final case class Cake[S <: State](name: String, state: S)
    
    // The implicit Checker for S tests the state at runtime; the cake is copied only on success.
    def narrow[S <: State](cake: Cake[State])(implicit checker: State.Checker[S]): Option[Cake[S]] =
      checker.check(cake.state).map(s => cake.copy(state = s))
    

    Which you can use like this:

    val rawCake: Cake[State] = Cake(name = "Foo", state = State.Raw)
    
    narrow[State.Raw](rawCake)
    // res: Option[Cake[State.Raw]] = Some(Cake(Foo,Raw))
    narrow[State.JustRight](rawCake)
    // res: Option[Cake[State.JustRight]] = None
    

    BTW, if you want to avoid the copy, you can just test whether the check succeeds (or change check to return a Boolean) and then use a dirty asInstanceOf:

    // Technically unsafe: the cast is unchecked because of erasure, but it is only
    // reached after the checker has confirmed that the state really is an S.
    def narrowUnsafe[S <: State](cake: Cake[State])(implicit checker: State.Checker[S]): Option[Cake[S]] =
      if (checker.check(cake.state).isDefined) Some(cake.asInstanceOf[Cake[S]])
      else None
    

    (You can see the code running here)
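If you go the Boolean route, a self-contained sketch could look like this. `IsState` and `narrowNoCopy` are hypothetical names standing in for a `Checker` whose `check` returns a Boolean; since the final cast is unchecked, soundness rests entirely on the instances being correct.

```scala
object BooleanCheckerDemo {
  sealed trait State extends Product with Serializable
  object State {
    case object Raw extends State
    type Raw = Raw.type
    final case class JustRight(temperature: Int) extends State
    final case class Burnt(charCoalContent: Double) extends State
  }
  final case class Cake[S <: State](name: String, state: S)

  // Hypothetical variant of Checker: only answers "is this state an S?".
  sealed trait IsState[S <: State] {
    def check(state: State): Boolean
  }
  object IsState {
    private def instance[S <: State](p: State => Boolean): IsState[S] =
      new IsState[S] { def check(state: State): Boolean = p(state) }

    implicit val raw: IsState[State.Raw] = instance[State.Raw](_ == State.Raw)
    implicit val justRight: IsState[State.JustRight] =
      instance[State.JustRight](_.isInstanceOf[State.JustRight])
    implicit val burnt: IsState[State.Burnt] =
      instance[State.Burnt](_.isInstanceOf[State.Burnt])
  }

  // No copy: the same instance is returned, cast only after the check passed.
  def narrowNoCopy[S <: State](cake: Cake[State])(implicit ev: IsState[S]): Option[Cake[S]] =
    if (ev.check(cake.state)) Some(cake.asInstanceOf[Cake[S]]) else None
}
```

Because no copy is made, a successful result is the very same object as the input (`eq` holds).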