Tags: scala, types, shapeless, scala-macros, type-level-computation

Scala type constraint to check argument values


I'm trying to implement Conway's surreal numbers in Scala. A surreal number is defined recursively – as a pair of sets of surreal numbers, called left and right, such that no element in the right set is less than or equal to any element in the left set. Here the relation "less than or equal to" between surreal numbers is also defined recursively: we say that x ≤ y if

  • there is no element a in the left set of x such that y ≤ a, and
  • there is no element b in the right set of y such that b ≤ x.

We start by defining zero as a pair of empty sets, then use zero to define 1 and -1, and so on.
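As a quick sanity check of these definitions, here is how they evaluate on the first few numbers. The notation {L | R} for a pair of sets and x^L, x^R for the left and right sets of x is the conventional one for surreals; the question itself does not use it. The results agree with the assertions in the code below.

```latex
\begin{aligned}
&0 = \{\ \mid\ \},\qquad 1 = \{0 \mid\ \},\qquad -1 = \{\ \mid 0\} \\
&x \le y \iff \neg\exists\, a \in x^L.\ y \le a \ \wedge\ \neg\exists\, b \in y^R.\ b \le x \\
&0 \le 0:\ 0^L = 0^R = \emptyset, \text{ so both conditions hold vacuously} \\
&0 \le 1:\ 0^L = \emptyset \text{ and } 1^R = \emptyset, \text{ so both conditions hold} \\
&1 \le 0 \text{ fails}:\ a = 0 \in 1^L \text{ and } 0 \le 0 \\
&\{1 \mid 0\} \text{ is not a surreal number}:\ 0 \in R,\ 1 \in L \text{ and } 0 \le 1
\end{aligned}
```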

I cannot figure out how to enforce the definition of a surreal number at compile time. This is what I have now:

case class SurrealNumber(left: Set[SurrealNumber], right: Set[SurrealNumber]) {
  if ((for { a <- left; b <- right; if b <= a } yield (a, b)).nonEmpty)
    throw new Exception
  def <=(other: SurrealNumber): Boolean =
    !this.left.exists(other <= _) && !other.right.exists(_ <= this)
}

val zero = SurrealNumber(Set.empty, Set.empty)
val one = SurrealNumber(Set(zero), Set.empty)
val minusOne = SurrealNumber(Set.empty, Set(zero))

assert(zero <= zero)
assert((zero <= one) && !(one <= zero))
assert((minusOne <= zero) && !(zero <= minusOne))

When the arguments are invalid, as in SurrealNumber(Set(one), Set(zero)), this code throws an exception at runtime. Is it possible to express the validity check as a type constraint, so that SurrealNumber(Set(one), Set(zero)) wouldn't compile?
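To make the runtime behaviour concrete, here is a self-contained, runnable sketch of the runtime-checked approach. It performs the check with `require` (which throws `IllegalArgumentException`) instead of `throw new Exception`; the `half` value ({0 | 1}) and the `SurrealDemo` object are illustrative additions, not part of the question.

```scala
import scala.util.Try

// Runtime-checked surreal numbers: the invariant is verified when an instance is constructed.
final case class SurrealNumber(left: Set[SurrealNumber], right: Set[SurrealNumber]) {
  // No element of the right set may be <= an element of the left set.
  require(!left.exists(a => right.exists(b => b <= a)), "invalid surreal number")

  def <=(other: SurrealNumber): Boolean =
    !this.left.exists(other <= _) && !other.right.exists(_ <= this)
}

object SurrealDemo {
  val zero = SurrealNumber(Set.empty, Set.empty)
  val one  = SurrealNumber(Set(zero), Set.empty)
  val half = SurrealNumber(Set(zero), Set(one)) // { 0 | 1 } is valid

  // { 1 | 0 } violates the invariant (0 <= 1), so construction fails at runtime.
  val invalid: Try[SurrealNumber] = Try(SurrealNumber(Set(one), Set(zero)))

  def main(args: Array[String]): Unit = {
    println(invalid.isFailure)               // true
    println((zero <= half) && (half <= one)) // true
  }
}
```

The invalid value is only detected when this code runs, which is exactly what the question wants to move to compile time.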


Solution

  • You could define a macro that performs the validity check at compile time:

    import scala.language.experimental.macros
    import scala.reflect.macros.blackbox

    // The constructor is private, so instances are created through the companion.
    case class SurrealNumber private(left: Set[SurrealNumber], right: Set[SurrealNumber]) {
      def <=(other: SurrealNumber): Boolean =
        !this.left.exists(other <= _) && !other.right.exists(_ <= this)
    }

    object SurrealNumber {
      // Unchecked constructor, used by the code the macro expands to.
      def unsafeApply(left: Set[SurrealNumber], right: Set[SurrealNumber]): SurrealNumber =
        new SurrealNumber(left, right)

      // Checked constructor: the arguments are evaluated during compilation.
      def apply(left: Set[SurrealNumber], right: Set[SurrealNumber]): SurrealNumber =
        macro applyImpl

      def applyImpl(c: blackbox.Context)(left: c.Tree, right: c.Tree): c.Tree = {
        import c.universe._
        // Evaluates an argument tree at compile time; this fails for trees that
        // refer to values only available at runtime.
        def eval[A](t: Tree): A = c.eval(c.Expr[A](c.untypecheck(t)))
        val l = eval[Set[SurrealNumber]](left)
        val r = eval[Set[SurrealNumber]](right)
        if ((for { a <- l; b <- r; if b <= a } yield (a, b)).nonEmpty)
          c.abort(c.enclosingPosition, "invalid surreal number")
        else q"SurrealNumber.unsafeApply($left, $right)"
      }
    }
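    One practical detail: Scala 2 def macros such as applyImpl must be compiled before the code that calls them, and since c.eval runs SurrealNumber code during compilation, the SurrealNumber class typically lives alongside the macro in its own sbt subproject. A minimal build.sbt sketch, where the project names macros and core and the Scala version are assumptions:

```scala
// build.sbt (sketch): project names and Scala version are illustrative assumptions
ThisBuild / scalaVersion := "2.13.12"

// Holds SurrealNumber and its macro implementation; compiled first
lazy val macros = (project in file("macros"))
  .settings(
    libraryDependencies += "org.scala-lang" % "scala-reflect" % scalaVersion.value
  )

// Code that calls SurrealNumber(...) and triggers the macro expansion
lazy val core = (project in file("core"))
  .dependsOn(macros)
```

    Call sites such as SurrealNumber(Set.empty, Set.empty) then go in core.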
    

    The catch is that the macro can only check arguments whose values the compiler can compute. The expression

    SurrealNumber(Set.empty, Set.empty)

    can be evaluated at compile time (it is zero), but in

    SurrealNumber(Set(zero), Set.empty)
    SurrealNumber(Set.empty, Set(zero))

    the values one and minusOne are built from zero, which is a runtime value that the compiler doesn't have access to. So

    SurrealNumber(Set(SurrealNumber(Set.empty, Set.empty)), Set.empty)
    SurrealNumber(Set.empty, Set(SurrealNumber(Set.empty, Set.empty)))

    compile, because every argument is written out inline, but

    SurrealNumber(Set(zero), Set.empty)
    SurrealNumber(Set.empty, Set(zero))

    don't, even though they describe valid numbers.


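    So you would need to redesign SurrealNumber to work at the type level, for example with shapeless, encoding the left and right sets as HLists of singleton types such as zero.type: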

    import shapeless.{::, HList, HNil, IsDistinctConstraint, OrElse, Poly1, Poly2, Refute, poly}
    import shapeless.ops.hlist.{CollectFirst, LeftReducer}
    import shapeless.test.illTyped
    
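    // L and R are the left and right sets, encoded as duplicate-free, sorted HLists of
    // singleton types; notExist rules out any l in L and r in R with r <= l.
    class SurrealNumber[L <: HList : IsDistinctConstraint : IsSorted,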
                        R <: HList : IsDistinctConstraint : IsSorted](implicit
      notExist: Refute[CollectFirst[L, CollectPoly[R]]]
    )
    
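    // S LEq S1 encodes S <= S1: no a in the left set of S with S1 <= a, and no b in the
    // right set of S1 with b <= S. Only the existence of instances matters, hence the nulls.
    trait LEq[S, S1]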
    object LEq {
      implicit def mkLEq[S,  L  <: HList,  R <: HList, 
                         S1, L1 <: HList, R1 <: HList](implicit
        ev:        S  <:< SurrealNumber[L, R],
        ev1:       S1 <:< SurrealNumber[L1, R1],
        notExist:  Refute[CollectFirst[L, FlippedLEqPoly[S1]]],
        notExist1: Refute[CollectFirst[R1, LEqPoly[S]]]
      ): S LEq S1 = null
    }
    
    // Applies to an element l of the left set iff some r in R satisfies r <= l.
    trait CollectPoly[R <: HList] extends Poly1
    object CollectPoly {
      implicit def cse[R <: HList, LElem](implicit 
        exist: CollectFirst[R, LEqPoly[LElem]]
      ): poly.Case1.Aux[CollectPoly[R], LElem, Unit] = null
    }
    
    // LEqPoly[X] applies to an element e iff e <= X.
    trait LEqPoly[FixedElem] extends Poly1
    object LEqPoly {
      implicit def cse[FixedElem, Elem](implicit 
        leq: Elem LEq FixedElem
      ): poly.Case1.Aux[LEqPoly[FixedElem], Elem, Unit] = null
    }
    
    // FlippedLEqPoly[X] applies to an element e iff X <= e.
    trait FlippedLEqPoly[FixedElem] extends Poly1
    object FlippedLEqPoly {
      implicit def cse[FixedElem, Elem](implicit 
        leq: FixedElem LEq Elem
      ): poly.Case1.Aux[FlippedLEqPoly[FixedElem], Elem, Unit] = null
    }
    
    // Reduces the list pairwise from the left; succeeds only if each element is <= the next.
    object isSortedPoly extends Poly2 {
      implicit def cse[Elem, Elem1](implicit 
        leq: Elem LEq Elem1
      ): Case.Aux[Elem, Elem1, Elem1] = null
    }
    type IsSorted[L <: HList] = (L <:< HNil) OrElse LeftReducer[L, isSortedPoly.type]
    
    val zero = new SurrealNumber[HNil, HNil]                  // { | }
    val one = new SurrealNumber[zero.type :: HNil, HNil]      // { 0 | }
    val minusOne = new SurrealNumber[HNil, zero.type :: HNil] // { | 0 }
    // { 1 | 0 } is rejected at compile time because 0 <= 1
    illTyped("new SurrealNumber[one.type :: HNil, zero.type :: HNil]")
    // { 0 | 1 } is valid (it represents 1/2)
    new SurrealNumber[zero.type :: HNil, one.type :: HNil]