
Scala with Cats exercise: data validation with Kleisli, why does my code fail fast instead of accumulating errors?


I'm reading the Scala with Cats book and following its exercises. When I got to the case study on data validation, I ran into some problems.

Here is my entire code (the same as in the book):

package org.scala.ch10.final_recap

import cats.Semigroup
import cats.data.Validated
import cats.data.Validated._
import cats.data.Kleisli
import cats.data.NonEmptyList
import cats.instances.either._
import cats.syntax.apply._
import cats.syntax.semigroup._
import cats.syntax.validated._

sealed trait Predicate[E, A] {
  import Predicate._
  def and(that: Predicate[E, A]): Predicate[E, A] =
    And(this, that)

  def or(that: Predicate[E, A]): Predicate[E, A] =
    Or(this, that)

  /**
   * This part is for Kleislis
   * @return
   */
  def run(implicit s: Semigroup[E]): A => Either[E, A] =
    (a: A) => this(a).toEither

  def apply(a: A)(implicit s: Semigroup[E]): Validated[E, A] =
    this match {
      case Pure(func) =>
        func(a)
      case And(left, right) => (left(a), right(a)).mapN((_, _) => a)
      case Or(left, right) =>
        left(a) match {
          case Valid(_) => Valid(a)
          case Invalid(e1) =>
            right(a) match {
              case Valid(_) => Valid(a)
              case Invalid(e2) => Invalid(e1 |+| e2)
            }
        }
    }
}

object Predicate {
  final case class And[E, A](left: Predicate[E, A], right: Predicate[E, A]) extends Predicate[E, A]
  final case class Or[E, A](left: Predicate[E, A], right: Predicate[E, A]) extends Predicate[E, A]
  final case class Pure[E, A](func: A => Validated[E, A]) extends Predicate[E, A]
  def apply[E, A](f: A => Validated[E, A]): Predicate[E, A] = Pure(f)
  def lift[E, A](err: E, fn: A => Boolean): Predicate[E, A] = Pure(a => if(fn(a)) a.valid else err.invalid)
}

object FinalRecapPredicate {
  type Errors = NonEmptyList[String]
  def error(s: String): NonEmptyList[String] = NonEmptyList(s, Nil)
  type Result[A] = Either[Errors, A]
  type Check[A, B] = Kleisli[Result, A, B]
  def check[A, B](func: A => Result[B]): Check[A, B] = Kleisli(func)
  def checkPred[A](pred: Predicate[Errors, A]): Check[A, A] =
    Kleisli[Result, A, A](pred.run)

  def longerThan(n: Int): Predicate[Errors, String] =
    Predicate.lift(
      error(s"Must be longer than $n characters"),
      str => str.length > n
    )

  val alphanumeric: Predicate[Errors, String] =
    Predicate.lift(
      error(s"Must be all alphanumeric characters"),
      str => str.forall(_.isLetterOrDigit)
    )

  def contains(char: Char): Predicate[Errors, String] =
    Predicate.lift(
      error(s"Must contain the character $char"),
      str => str.contains(char)
    )

  def containsOnce(char: Char): Predicate[Errors, String] =
    Predicate.lift(
      error(s"Must contain the character $char only once"),
      str => str.count(_ == char) == 1
    )

  val checkUsername: Check[String, String] = checkPred(longerThan(3) and alphanumeric)

  val splitEmail: Check[String, (String, String)] = check(_.split('@') match {
    case Array(name, domain) =>
      Right((name, domain))
    case _ =>
      Left(error("Must contain a single @ character"))
  })

  val checkLeft: Check[String, String] = checkPred(longerThan(0))

  val checkRight: Check[String, String] = checkPred(longerThan(3) and contains('.'))

  val joinEmail: Check[(String, String), String] =
    check {
      case (l, r) => (checkLeft(l), checkRight(r)).mapN(_ + "@" + _)
    }

  val checkEmail: Check[String, String] = splitEmail andThen joinEmail

  final case class User(username: String, email: String)

  def createUser(username: String, email: String): Either[Errors, User] =
    (checkUsername.run(username),
      checkEmail.run(email)).mapN(User)

  def main(args: Array[String]): Unit = {
    println(createUser("", "[email protected]@io"))
  }
}

The code is supposed to end with the error message Left(NonEmptyList(Must be longer than 3 characters, Must contain a single @ character)), but what I actually get is Left(NonEmptyList(Must be longer than 3 characters)).

Obviously, it does not work as expected: it fails fast instead of accumulating errors. How can I fix that? (I've spent hours on this and can't find a workaround.)


Solution

  • This is the "problematic" part:

    def createUser(username: String, email: String): Either[Errors, User] =
      (checkUsername.run(username),
        checkEmail.run(email)).mapN(User)
    

    You are combining a tuple of Results, where

    type Result[A] = Either[Errors, A]
    

    This means you are really calling mapN on a pair of Eithers. mapN is provided by the Semigroupal type class, and Cats' Semigroupal instance for Either is the one that comes with its Monad, so it stops at the first Left and does not accumulate errors.

    There are several reasons for this, but the one I find most important is keeping behaviour consistent when a Semigroupal / Applicative also happens to be a Monad. Why does that matter? Monads sequence operations: each step depends on the previous one, which gives them "fail early" semantics. Since every Monad is also an Applicative, one expects the Applicative operations to agree with the ones derived from flatMap. If Either's Applicative accumulated errors while its flatMap failed early, the two would disagree, which breaks the laws requiring ap to be consistent with flatMap and makes it unsafe to switch between the applicative and monadic style of the same code.
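    A minimal sketch of that consistency argument, using two Left values like the ones in the example at the end of this answer (the object and value names are made up for illustration). mapN on Either gives the same result as the flatMap chain, and the flatMap chain can never see the second error because the second step never runs.

    ```scala
    import cats.implicits._

    object ConsistencySketch {
      val l1: Either[String, Int] = Left("foo")
      val l2: Either[String, Int] = Left("bar")

      // Applicative-style combination via Either's Semigroupal instance.
      val viaMapN: Either[String, Int] = (l1, l2).mapN(_ + _)

      // Monadic combination: l1 is a Left, so the function passed to flatMap
      // is never called and l2 is never inspected.
      val viaFlatMap: Either[String, Int] = l1.flatMap(a => l2.map(b => a + b))
    }
    ```

    Both values are Left("foo"). Cats checks this kind of agreement in its law tests (the FlatMap laws include flatMapConsistentApply), which is why Either's mapN cannot accumulate.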

    You can use a parallel version of mapN, called parMapN. It goes through the Parallel type class, which pairs a Monad with a separate Applicative that is not required to agree with flatMap; for Either[E, *] that Applicative is Validated[E, *]. This means parMapN cannot be expected to have "fail early" semantics, and as long as E has a Semigroup (NonEmptyList does), the errors are accumulated.
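    A minimal sketch of the fix for createUser, assuming Cats 2.x with `import cats.implicits._`. To keep it self-contained, checkUsername and checkEmail here are hypothetical plain functions standing in for `checkUsername.run` and `checkEmail.run` from the question; like the Kleisli checks, they return Left on failure.

    ```scala
    import cats.data.NonEmptyList
    import cats.implicits._

    object ParMapNSketch {
      type Errors = NonEmptyList[String]
      type Result[A] = Either[Errors, A]

      final case class User(username: String, email: String)

      // Hypothetical stand-in for checkUsername.run from the question.
      def checkUsername(s: String): Result[String] =
        if (s.length > 3) Right(s)
        else Left(NonEmptyList.one("Must be longer than 3 characters"))

      // Hypothetical stand-in for checkEmail.run from the question.
      def checkEmail(s: String): Result[String] =
        if (s.count(_ == '@') == 1) Right(s)
        else Left(NonEmptyList.one("Must contain a single @ character"))

      // parMapN uses Parallel[Either[Errors, *]], whose Applicative is
      // Validated[Errors, *], so both Lefts are combined with NonEmptyList's Semigroup.
      def createUser(username: String, email: String): Result[User] =
        (checkUsername(username), checkEmail(email)).parMapN((name, mail) => User(name, mail))
    }
    ```

    With an empty username and an email containing two @ characters, this returns both messages in one NonEmptyList. The same substitution applies to the mapN inside joinEmail; the andThen in checkEmail stays sequential by design, since joinEmail needs the output of splitEmail.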

    Note that Validated accumulates errors as well, usually in a NonEmptyList or a NonEmptyChain. This is probably why you expected to see accumulated errors; the only problem is that in the "problematic" part of your code you were not combining Validated values, but raw Eithers.
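    Another way to get accumulation without Parallel is to convert the Either results to Validated explicitly, combine them with mapN, and convert back. A minimal sketch, assuming Cats 2.x; the two Left values are hypothetical stand-ins for the results of `checkUsername.run(...)` and `checkEmail.run(...)`.

    ```scala
    import cats.data.NonEmptyList
    import cats.implicits._

    object ValidatedSketch {
      type Errors = NonEmptyList[String]
      final case class User(username: String, email: String)

      // Hypothetical results standing in for the two Kleisli checks.
      val usernameResult: Either[Errors, String] =
        Left(NonEmptyList.one("Must be longer than 3 characters"))
      val emailResult: Either[Errors, String] =
        Left(NonEmptyList.one("Must contain a single @ character"))

      // toValidated comes from Cats' Either syntax; mapN on Validated combines
      // both Invalid values with Semigroup[Errors], then toEither converts back.
      val user: Either[Errors, User] =
        (usernameResult.toValidated, emailResult.toValidated)
          .mapN((name, email) => User(name, email))
          .toEither
    }
    ```

    This is essentially what parMapN does for you behind the scenes, spelled out by hand.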

    Here's some simple code that demonstrates the above concepts:

    import cats.data._
    import cats.implicits._
    
    val l1: Either[String, Int] = Left("foo")
    val l2: Either[String, Int] = Left("bar")
    
    (l1, l2).mapN(_ + _) 
    // Left(foo)
    
    (l1, l2).parMapN(_ + _) 
    // Left(foobar)
    
    val v1: ValidatedNel[String, Int] = l1.toValidatedNel
    val v2: ValidatedNel[String, Int] = l2.toValidatedNel
    
    (v1, v2).mapN(_ + _) 
    // Invalid(NonEmptyList(foo, bar))