Search code examples
scala · scala-cats · io-monad · cats-effect

Expression Evaluation (Add, Mult, etc.) with Cats-Effect


I am new to cats-effect and I am trying to implement the classical expression evaluation using cats-effect. Using `eval`, I would like to return an `IO[Double]` instead of a `Double`. My naive code is below, but of course it doesn't type check. What is the right way to approach this? (In general, it seems difficult to combine pattern matching with `IO`s.)

import cats.effect._
import cats.effect.unsafe.implicits.global

/** Arithmetic expression tree. Sealed so that matches over it are checked for exhaustiveness. */
sealed trait Expression
/** Sum of two sub-expressions. */
final case class Add(x: Expression, y: Expression) extends Expression
/** Product of two sub-expressions. */
final case class Mult(x: Expression, y: Expression) extends Expression
/** Natural exponential (`scala.math.exp`) of a sub-expression. */
final case class Exp(x: Expression) extends Expression
/** A literal number. */
final case class Const(x: Double) extends Expression

/** Operator syntax for building expression trees, e.g. `Const(1) + Const(2)`. */
extension (exp: Expression)
    /** Builds the node `Add(exp, other)`. */
    def +(other: Expression): Add = Add(x = exp, y = other)
    /** Builds the node `Mult(exp, other)`. */
    def *(other: Expression): Mult = Mult(x = exp, y = other)

/** The asker's first attempt: evaluate `exp` and return the result wrapped in `IO`.
  *
  * NOTE: as the question says, this does not compile. `IO { ... }` with result type `IO[Double]`
  * needs the block to produce a plain `Double`, but each recursive `eval(...)` call returns an
  * `IO[Double]`, so the arithmetic below is applied to `IO` values rather than to numbers.
  */
def eval(exp: Expression): IO[Double] = IO{
    exp match
    case Add(x, y) => eval(x) + eval(y) // This does not type check: `+` is applied to two IO[Double] values
    case Mult(x, y) => eval(x) * eval(y) // same problem: `*` is applied to two IO[Double] values
    case Exp(x) => scala.math.exp(eval(x)) // same problem: scala.math.exp expects a Double, not an IO[Double]
    case Const(x) => x // fine on its own: x is already a Double
}

/** Sample tree for exp((1 + 2) * 9), written with the `+`/`*` extension operators. */
val expression1: Exp = Exp((Const(1) + Const(2)) * Const(9))

/** Entry point: runs the evaluation at the edge of the program and prints the result. */
@main def main: Unit =
    val result: Double = eval(expression1).unsafeRunSync()
    println(result)

Solution

  • `IO` is a monad, so you can sequence the recursive calls with for-comprehensions:

    /** Evaluates `exp` inside `IO`, sequencing the recursive calls with for-comprehensions.
      *
      * Each `<-` binds the result of a sub-evaluation; `yield` combines the bound numbers.
      */
    def eval(exp: Expression): IO[Double] =
      exp match
        case Add(x, y)  => for {
          x1 <- eval(x)
          y1 <- eval(y)
        } yield x1 + y1
        case Mult(x, y) => for {
          x1 <- eval(x)
          y1 <- eval(y)
        } yield x1 * y1
        case Exp(x)     => for {
          x1 <- eval(x)
        } yield scala.math.exp(x1)
        case Const(x)   => IO.pure(x) // already a value: lift it instead of suspending it
    

    Or use applicative syntax (`mapN`), since the two sub-evaluations do not depend on each other:

    import cats.syntax.apply.given
    
    /** Evaluates `exp` inside `IO`, combining the two independent sub-results with `mapN`. */
    def eval(exp: Expression): IO[Double] =
      exp match
        case Add(x, y)  => (eval(x), eval(y)).mapN(_ + _)
        case Mult(x, y) => (eval(x), eval(y)).mapN(_ * _)
        case Exp(x)     => eval(x).map(scala.math.exp)
        case Const(x)   => IO.pure(x) // already a value: lift it instead of suspending it
    

    Or define an instance of the `Numeric` type class for `IO`, so that `+` and `*` work on `IO` values directly:

    import cats.syntax.apply.given
    import Numeric.Implicits.given
    
    /** `Numeric` instance for `IO[A]` that lifts `A`'s arithmetic into `IO`.
      *
      * - Binary operations combine the two effects with `mapN` and apply `A`'s own operation to the results.
      * - `negate` maps over the single effect.
      * - `fromInt` and `parseString` lift already-computed values with `IO.pure`.
      *
      * Conversions out of `IO` (`toInt`, `toLong`, `toFloat`, `toDouble`) and `compare` could
      * only be written by running the effect, so they remain `???` and throw
      * `NotImplementedError` if called. Because `Numeric` extends `Ordering`, anything built
      * on `compare` (e.g. `abs`, `signum`, `lt`, `max`) throws as well. This instance is meant
      * only to provide infix arithmetic syntax.
      */
    given [A: Numeric]: Numeric[IO[A]] = new Numeric[IO[A]]:
      // Evidence for the element type, used by the methods that cannot go through infix syntax.
      private val num: Numeric[A] = summon[Numeric[A]]

      override def plus(x: IO[A], y: IO[A]): IO[A]  = (x, y).mapN(_ + _)
      override def times(x: IO[A], y: IO[A]): IO[A] = (x, y).mapN(_ * _)
      override def minus(x: IO[A], y: IO[A]): IO[A] = (x, y).mapN(_ - _)
      override def negate(x: IO[A]): IO[A] = x.map(num.negate)
      override def fromInt(x: Int): IO[A]  = IO.pure(num.fromInt(x))
      override def parseString(str: String): Option[IO[A]] = num.parseString(str).map(a => IO.pure(a))
      // The members below would have to run the IO to produce a plain value, so they are left unimplemented.
      override def toInt(x: IO[A]): Int       = ???
      override def toLong(x: IO[A]): Long     = ???
      override def toFloat(x: IO[A]): Float   = ???
      override def toDouble(x: IO[A]): Double = ???
      override def compare(x: IO[A], y: IO[A]): Int = ???
    
    /** Evaluates `exp` inside `IO`; `+` and `*` on `IO[Double]` come from the `Numeric[IO[A]]` given. */
    def eval(exp: Expression): IO[Double] =
      exp match
        case Add(x, y)  => eval(x) + eval(y)
        case Mult(x, y) => eval(x) * eval(y)
        case Exp(x)     => eval(x).map(scala.math.exp)
        case Const(x)   => IO.pure(x) // already a value: lift it instead of suspending it
    

    Or define your own syntax as extension methods on `IO`:

    import cats.syntax.apply.given
    import Numeric.Implicits.given
    
    /** Infix arithmetic on effectful numbers: each operator runs `x`, then `y`, and combines the two results. */
    extension [A: Numeric](x: IO[A])
      /** Sum of the two results. */
      def +(y: IO[A]): IO[A] =
        for
          a <- x
          b <- y
        yield a + b
      /** Product of the two results. */
      def *(y: IO[A]): IO[A] =
        for
          a <- x
          b <- y
        yield a * b

    /** Applies `scala.math.exp` to the result of `x`. */
    def exp(x: IO[Double]): IO[Double] = x.map(d => scala.math.exp(d))
    
    /** Evaluates `expr` inside `IO` using the hand-written `+`/`*` extensions and the `exp` helper. */
    def eval(expr: Expression): IO[Double] =
      expr match
        case Add(x, y)  => eval(x) + eval(y)
        case Mult(x, y) => eval(x) * eval(y)
        case Exp(x)     => exp(eval(x))
        case Const(x)   => IO.pure(x) // already a value: lift it instead of suspending it
    

    Or just evaluate the whole tree as a plain `Double` and wrap the computation in `IO` once:

    /** Evaluates `expr` as a plain `Double` and suspends that whole computation in a single `IO`. */
    def eval(expr: Expression): IO[Double] =
      // Pure recursive evaluator; nothing is effectful until the final IO.delay.
      def go(node: Expression): Double = node match
        case Const(value) => value
        case Exp(inner)   => scala.math.exp(go(inner))
        case Add(l, r)    => go(l) + go(r)
        case Mult(l, r)   => go(l) * go(r)

      IO.delay(go(expr))
    end eval