0

I am new to cats-effect and I am trying to implement the classical expression evaluation using cats-effect. I would like eval to return an IO[Double] instead of a Double. I have my naive code below, but of course it doesn't type-check. What is the right way to approach this? (It seems that pattern matching is generally difficult to combine with IO.)

import cats.effect._
import cats.effect.unsafe.implicits.global

/** Arithmetic expression tree. `sealed`, so matches over it are checked for exhaustiveness. */
sealed trait Expression
/** Sum of two sub-expressions. */
final case class Add(x: Expression, y: Expression) extends Expression
/** Product of two sub-expressions. */
final case class Mult(x: Expression, y: Expression) extends Expression
/** Exponential of a sub-expression (evaluated with `scala.math.exp`). */
final case class Exp(x: Expression) extends Expression
/** Literal constant. */
final case class Const(x: Double) extends Expression

/** Infix builders so that expression trees can be written as `a + b` and `a * b`. */
extension (exp: Expression)
    /** Builds the sum node `Add(exp, other)`. */
    def +(other: Expression): Add = Add(exp, other)
    /** Builds the product node `Mult(exp, other)`. */
    def *(other: Expression): Mult = Mult(exp, other)

// Naive attempt from the question: this version does NOT compile.
// `eval` is declared to return IO[Double], so every recursive call below yields an
// IO[Double], while the arithmetic written here is the one for plain Double values.
def eval(exp: Expression): IO[Double] = IO{
    exp match
    case Add(x, y) => eval(x) + eval(y) // This does not type check
    case Mult(x, y) => eval(x) * eval(y)
    case Exp(x) => scala.math.exp(eval(x))
    case Const(x) => x
}

// Sample input: Exp((1 + 2) * 9), built with the `+` / `*` extension operators.
val expression1: Exp =
    Exp((Const(1.0) + Const(2.0)) * Const(9.0))

/** Entry point: runs the evaluation of `expression1` synchronously and prints the result. */
@main def main: Unit =
    val result = eval(expression1).unsafeRunSync()
    println(result)

1 Answer

3

IO is a monad. Try for-comprehensions

/** Evaluates `exp` inside `IO`, sequencing sub-results with for-comprehensions.
  *
  * The whole match is wrapped in `IO.defer`, so each recursive call is suspended and
  * interpreted by the IO run loop instead of being nested on the JVM call stack. Without
  * it, `eval(x)` on the left operand recurses eagerly while the IO value is being built,
  * which can overflow the stack on deeply nested expressions.
  *
  * @param exp expression tree to evaluate
  * @return an `IO` that yields the numeric value of `exp` when run
  */
def eval(exp: Expression): IO[Double] =
  IO.defer {
    exp match
      case Add(x, y) =>
        for {
          x1 <- eval(x)
          y1 <- eval(y)
        } yield x1 + y1
      case Mult(x, y) =>
        for {
          x1 <- eval(x)
          y1 <- eval(y)
        } yield x1 * y1
      // A single generator needs no for-comprehension; `map` says the same thing.
      case Exp(x) => eval(x).map(scala.math.exp)
      // The constant is already computed, so there is nothing to suspend.
      case Const(x) => IO.pure(x)
  }

or applicative syntax

import cats.syntax.apply.given

/** Evaluates `exp` inside `IO`, combining sub-results with applicative `mapN`.
  *
  * `(eval(x), eval(y))` would otherwise build both operands eagerly, recursing on the JVM
  * stack. Wrapping the match in `IO.defer` suspends every recursive step so the IO run
  * loop drives the recursion instead.
  *
  * @param exp expression tree to evaluate
  * @return an `IO` that yields the numeric value of `exp` when run
  */
def eval(exp: Expression): IO[Double] =
  IO.defer {
    exp match
      case Add(x, y)  => (eval(x), eval(y)).mapN(_ + _)
      case Mult(x, y) => (eval(x), eval(y)).mapN(_ * _)
      case Exp(x)     => eval(x).map(scala.math.exp)
      // The constant is already computed, so there is nothing to suspend.
      case Const(x)   => IO.pure(x)
  }

or define an instance of the type class Numeric

import cats.syntax.apply.given
import Numeric.Implicits.given

/** Lifts a `Numeric[A]` into `Numeric[IO[A]]` so that `+`, `-`, `*` work directly on `IO` values.
  *
  * Arithmetic and construction are delegated to the underlying `Numeric[A]` inside the
  * effect. The extraction/comparison members (`toInt`, `toLong`, `toFloat`, `toDouble`,
  * `compare`) must return a plain value, which is impossible without running the `IO`;
  * they are deliberately left unimplemented and throw `NotImplementedError` if called.
  */
given [A: Numeric]: Numeric[IO[A]] = new Numeric[IO[A]]:
  override def plus(x: IO[A], y: IO[A]): IO[A]  = (x, y).mapN(_ + _)
  override def times(x: IO[A], y: IO[A]): IO[A] = (x, y).mapN(_ * _)
  override def minus(x: IO[A], y: IO[A]): IO[A] = (x, y).mapN(_ - _)
  override def negate(x: IO[A]): IO[A] = x.map(a => -a)
  override def fromInt(x: Int): IO[A]  = IO.pure(summon[Numeric[A]].fromInt(x))
  override def parseString(str: String): Option[IO[A]] =
    summon[Numeric[A]].parseString(str).map(a => IO.pure(a))
  // The members below would have to run the effect to produce a result; not supported.
  override def toInt(x: IO[A]): Int       = ???
  override def toLong(x: IO[A]): Long     = ???
  override def toFloat(x: IO[A]): Float   = ???
  override def toDouble(x: IO[A]): Double = ???
  override def compare(x: IO[A], y: IO[A]): Int = ???

/** Evaluates `exp` inside `IO`, using the `Numeric[IO[Double]]` instance for `+` and `*`.
  *
  * The match is wrapped in `IO.defer` so that the recursive calls on both operands are
  * suspended and run by the IO run loop rather than recursing eagerly on the JVM stack.
  *
  * @param exp expression tree to evaluate
  * @return an `IO` that yields the numeric value of `exp` when run
  */
def eval(exp: Expression): IO[Double] =
  IO.defer {
    exp match
      case Add(x, y)  => eval(x) + eval(y)
      case Mult(x, y) => eval(x) * eval(y)
      case Exp(x)     => eval(x).map(scala.math.exp)
      // The constant is already computed, so there is nothing to suspend.
      case Const(x)   => IO.pure(x)
  }

or define your own syntax

import cats.syntax.apply.given
import Numeric.Implicits.given

/** Arithmetic syntax for effectful numbers: runs both operands, then combines the results. */
extension [A: Numeric](x: IO[A])
  /** Sum of the two results, computed with `Numeric[A].plus`. */
  def +(y: IO[A]): IO[A] = (x, y).mapN((a, b) => summon[Numeric[A]].plus(a, b))
  /** Product of the two results, computed with `Numeric[A].times`. */
  def *(y: IO[A]): IO[A] = (x, y).mapN((a, b) => summon[Numeric[A]].times(a, b))

/** Applies `scala.math.exp` to the result of `x`. */
def exp(x: IO[Double]): IO[Double] =
  x.map(value => scala.math.exp(value))

/** Evaluates `expr` inside `IO`, using the custom `+`, `*` and `exp` syntax for `IO` values.
  *
  * The match is wrapped in `IO.defer` so that the recursive calls are suspended and run by
  * the IO run loop; otherwise the arguments of `+`, `*` and `exp` are built by eager
  * recursion on the JVM stack.
  *
  * @param expr expression tree to evaluate
  * @return an `IO` that yields the numeric value of `expr` when run
  */
def eval(expr: Expression): IO[Double] =
  IO.defer {
    expr match
      case Add(x, y)  => eval(x) + eval(y)
      case Mult(x, y) => eval(x) * eval(y)
      case Exp(x)     => exp(eval(x))
      // The constant is already computed, so there is nothing to suspend.
      case Const(x)   => IO.pure(x)
  }

or just

/** Evaluates `expr` with ordinary recursion over plain `Double`s and suspends the whole
  * computation in a single `IO`.
  */
def eval(expr: Expression): IO[Double] =
  // Pure evaluator; nothing here touches IO.
  def compute(node: Expression): Double = node match
    case Const(value)      => value
    case Exp(inner)        => scala.math.exp(compute(inner))
    case Add(left, right)  => compute(left) + compute(right)
    case Mult(left, right) => compute(left) * compute(right)

  IO.delay(compute(expr))
end eval
Dmytro Mitin
  • 48,194
  • 3
  • 28
  • 66
  • 1
    I'd prefer to use `IO.defer { ... }` or `IO.pure(expr).flatMap { ... }` immediately in `def eval` to make sure that _every_ part of the computation is in IO. Otherwise it is easy to accidentally write a stack-unsafe implementation, and you won't figure out where without some careful analysis. – Mateusz Kubuszok Apr 03 '23 at 18:39