Tags: scala, functional-programming, monads, scala-3

How to write generic Monad law tests?


Given the definitions:

trait Functor[F[_]]:
  extension [A](fa: F[A]) def map[B](f: A => B): F[B]

trait Monad[F[_]] extends Functor[F]:
  def unit[A](a: => A): F[A]

  extension [A](fa: F[A])
    def flatMap[B](f: A => F[B]): F[B]

    def map[B](f: A => B): F[B] =
      flatMap(f.andThen(unit))

object MonadSyntax:
  // 'fA' is the thing on which the method >>= is invoked.
  extension [F[_], A](fA: F[A])(using mA: Monad[F])
    def >>=[B](f: A => F[B]): F[B] = mA.flatMap(fA)(f)
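For reference, the three Monad laws can be written as plain generic predicates over these definitions and spot-checked on Option without any test library. This is a minimal sketch under a few assumptions: the question imports MonadInstances.optionMonad later but never shows it, so the instance body here is assumed; the *Holds predicates and OptionSpotCheck are hypothetical names; and a few hand-picked values stand in for ScalaCheck's generators. The traits are repeated so the block compiles on its own.

```scala
// Definitions repeated from the question so the sketch compiles on its own;
// map is written as flatMap(a => unit(f(a))), equivalent to f.andThen(unit).
trait Functor[F[_]]:
  extension [A](fa: F[A]) def map[B](f: A => B): F[B]

trait Monad[F[_]] extends Functor[F]:
  def unit[A](a: => A): F[A]

  extension [A](fa: F[A])
    def flatMap[B](f: A => F[B]): F[B]

    def map[B](f: A => B): F[B] =
      flatMap(a => unit(f(a)))

// Assumed shape of the question's (unshown) MonadInstances object.
object MonadInstances:
  given optionMonad: Monad[Option] with
    def unit[A](a: => A): Option[A] = Some(a)

    extension [A](fa: Option[A])
      def flatMap[B](f: A => Option[B]): Option[B] = fa match
        case Some(a) => f(a)
        case None    => None

// Left identity: unit(a).flatMap(f) == f(a)
def leftIdentityHolds[F[_], A, B](a: A, f: A => F[B])(using m: Monad[F]): Boolean =
  m.unit(a).flatMap(f) == f(a)

// Right identity: fa.flatMap(unit) == fa
def rightIdentityHolds[F[_], A](fa: F[A])(using m: Monad[F]): Boolean =
  fa.flatMap(a => m.unit(a)) == fa

// Associativity: fa.flatMap(f).flatMap(g) == fa.flatMap(a => f(a).flatMap(g))
def associativityHolds[F[_], A, B, C](fa: F[A], f: A => F[B], g: B => F[C])(using
    Monad[F]
): Boolean =
  fa.flatMap(f).flatMap(g) == fa.flatMap(a => f(a).flatMap(g))

object OptionSpotCheck:
  import MonadInstances.optionMonad

  // Hand-picked samples stand in for the values ScalaCheck would generate.
  val values: List[Option[Int]] = List(None, Some(0), Some(7))
  val f: Int => Option[String] = n => if n % 2 == 0 then Some(n.toString) else None
  val g: String => Option[Int] = s => Some(s.length)

  // True only if every law holds for every sample.
  val allLawsHold: Boolean =
    List(0, 7).forall(a => leftIdentityHolds(a, f)) &&
      values.forall(fa => rightIdentityHolds(fa)) &&
      values.forall(fa => associativityHolds(fa, f, g))
```

A ScalaCheck-based version differs mainly in where the inputs come from: forAll draws a, fa, f and g from Arbitrary instances instead of fixed lists.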

I want to write tests using ScalaTest and ScalaCheck to verify the three Monad laws. Instead of repeating the tests for every Monad instance, I'd like to make the tests generic. My attempt is as follows:

trait MonadLaws[F[_]] { this: AnyFunSpec & ScalaCheckPropertyChecks =>
  // Monad[F].unit(x).flatMap(f) === f(x)
  def leftIdentity[A: Arbitrary, B: Arbitrary](
    using arbAToFB: Arbitrary[A => F[B]], 
    ma: Monad[F]
    ): Unit =
    it("should satisfy the left identity law"):
      forAll { (x: A, f: A => F[B]) =>
        (ma.unit(x) >>= f) shouldBe f(x)
      }

  // m.flatMap(Monad[F].unit) === m
  def rightIdentity[A: Arbitrary](using eqM: Equality[Monad[F]], ma: Monad[F]): Unit =
    it("should satisfy the right identity law"):
      val left = for
        a: A <- ma
      yield summon[Monad[F]].unit(a)

      left shouldBe ma

  // m.flatMap(f).flatMap(g) === m.flatMap { x => f(x).flatMap(g) }
  def associativityLaw[A: Arbitrary, B: Arbitrary, C: Arbitrary](
    using arbAToFB: Arbitrary[A => F[B]],
    arbBToFB: Arbitrary[B => F[C]],
    ma: Monad[F],
    eqM: Equality[Monad[F]]
  ): Unit =
    it("should satisfy the associativity law"):
      forAll { (f: A => F[B], g: B => F[C]) =>
        ma.flatMap(f).flatMap(g) shouldBe m.flatMap(x => f(x).flatMap(g))
      }
}

Each Monad instance would then run the tests like this:

class OptionMonadLawsSpec extends AnyFunSpec with ScalaCheckPropertyChecks with MonadLaws[Option]:
  describe("Options form a Monad"):
    import MonadInstances.optionMonad

    leftIdentity[Int, Int]

Imports have been omitted for brevity.

The problem is that the second and third tests don't compile; both fail with:

Found:    A => F[A]
Required: F[A²]

where:    A  is a type in method rightIdentity
          A² is a type variable with constraint 
          F  is a type in trait MonadLaws with bounds <: [_] =>> Any

I've racked my brain for more than an hour and can't get the types right. Also, how do I invoke Monad[F].unit where unit is an instance method?


Solution

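  • @LuisMiguelMejíaSuárez helped me on the Typelevel Discord channel to get the code compiling. I've yet to test it thoroughly, but the types check out, and he was also kind enough to point out a basic misunderstanding. The laws quantify over monadic values m: F[A], not over the Monad[F] instance, so each test has to generate F[A] values (via an Arbitrary[F[A]]) and compare the resulting F values with an Equality for that type (Equality[F[B]], Equality[F[A]] or Equality[F[C]]), rather than treating the Monad[F] instance as the value under test and comparing with Equality[Monad[F]].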

    // Imports that the tests rely on; Matchers supplies shouldEqual.
    import org.scalacheck.Arbitrary
    import org.scalactic.Equality
    import org.scalatest.funspec.AnyFunSpec
    import org.scalatest.matchers.should.Matchers.*
    import org.scalatestplus.scalacheck.ScalaCheckPropertyChecks
    import MonadSyntax.* // provides >>=

    trait MonadLaws[F[_]] { this: AnyFunSpec & ScalaCheckPropertyChecks =>
      // Monad[F].unit(x).flatMap(f) === f(x)
      def leftIdentity[A, B](using
          Monad[F],
          Arbitrary[A],
          Arbitrary[A => F[B]],
          Equality[F[B]]
      ): Unit =
        it("should satisfy the left identity law"):
          forAll { (a: A, f: A => F[B]) =>
            val lhs = summon[Monad[F]].unit(a) >>= f
            val rhs = f(a)
    
            // shouldEqual consults the Equality in scope; shouldBe would ignore it.
            lhs shouldEqual rhs
          }
    
      // m.flatMap(Monad[F].unit) === m
      def rightIdentity[A](using Monad[F], Arbitrary[F[A]], Equality[F[A]]): Unit =
        it("should satisfy the right identity law"):
          forAll { (fa: F[A]) =>
            val lhs = fa >>= (a => summon[Monad[F]].unit(a))
            val rhs = fa
    
            lhs shouldEqual rhs
          }
    
      // m.flatMap(f).flatMap(g) === m.flatMap { x => f(x).flatMap(g) }
      def associativityLaw[A, B, C](using
          Monad[F],
          Arbitrary[F[A]],
          Arbitrary[A => F[B]],
          Arbitrary[B => F[C]],
          Equality[F[C]]
      ): Unit =
        it("should satisfy the associativity law"):
          forAll { (fa: F[A], f: A => F[B], g: B => F[C]) =>
            // >>= is left-associative; the parentheses only make that explicit.
            val lhs = (fa >>= f) >>= g
            val rhs = fa >>= (a => f(a) >>= g)
    
            lhs shouldEqual rhs
          }
    }
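On the question's second point, calling unit on the instance: the solution uses summon[Monad[F]].unit(a). If the Monad[F].unit(x) spelling from the law comments is preferred, a common Scala 3 idiom is a summoner apply method in the Monad companion object. This sketch is an addition, not part of the question's code: the companion object, the optionMonad given placed inside it, leftIdentityHolds and half are hypothetical names, and the traits are repeated so the block compiles on its own.

```scala
trait Functor[F[_]]:
  extension [A](fa: F[A]) def map[B](f: A => B): F[B]

trait Monad[F[_]] extends Functor[F]:
  def unit[A](a: => A): F[A]

  extension [A](fa: F[A])
    def flatMap[B](f: A => F[B]): F[B]

    def map[B](f: A => B): F[B] =
      flatMap(a => unit(f(a)))

object Monad:
  // Summoner: Monad[F] expands to Monad.apply[F] and returns the given in scope.
  def apply[F[_]](using m: Monad[F]): Monad[F] = m

  // Hypothetical instance; living in the companion puts it in implicit scope.
  given optionMonad: Monad[Option] with
    def unit[A](a: => A): Option[A] = Some(a)

    extension [A](fa: Option[A])
      def flatMap[B](f: A => Option[B]): Option[B] = fa match
        case Some(a) => f(a)
        case None    => None

// Left identity written the way the question's comment states it.
def leftIdentityHolds[F[_]: Monad, A, B](x: A, f: A => F[B]): Boolean =
  Monad[F].unit(x).flatMap(f) == f(x)

val half: Int => Option[Int] = n => if n % 2 == 0 then Some(n / 2) else None
val wrapped: Option[String] = Monad[Option].unit("hello")
```

Because the given lives in the companion object, Monad[Option] resolves without an import; with a separate object such as the question's MonadInstances, the import would still be needed.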