Search code examples
Tags: scala, user-defined-types, scala-macros, newtype

Do I need to define functions when using newtype in Scala?


I'm trying to learn about types more and use them to help avoid silly errors.

I asked a similar question earlier, but ran into problems similar to the ones shown below (I think), and ultimately abandoned that answer. A comment on that question suggested I look into newtype.

It seems simpler on the surface, but I'm still left scratching my head.

I have this code:

  package com.craigtreptow.scrayz
  import io.estatico.newtype.macros.newtype

    package object Colors {

    // Scales every channel of `c` by `n`.
    // NOTE: as written this does NOT compile. `c.red` is a `Red` newtype, and the
    // newtype does not expose `Double`'s `*` ("value * is not a member of ... Red").
    def multiply(c: Color, n: Double): Color = {
      Color(
        c.red   * n, // compile error: Red has no `*` member
        c.green * n,
        c.blue  * n
       )
    }

    // A `Double` given its own distinct type via the `@newtype` macro.
    @newtype case class Red(toDouble: Double)
    // Only the red channel is wrapped in a newtype; green and blue are still plain Doubles.
    case class Color(red: Red, green: Double, blue: Double)
}

The above produces these errors:

[info] Compiling 1 Scala source to /Users/Ctreptow/code/scrayz/target/scala-2.13/classes ...
[error] /Users/Ctreptow/code/scrayz/src/main/scala/com/craigtreptow/scrayz/Colors/package.scala:11:15: value * is not a member of com.craigtreptow.scrayz.Colors.package.Red
[error]       c.red   * n,
[error]               ^
[error] one error found
[error] (Compile / compileIncremental) Compilation failed
[error] Total time: 4 s, completed Apr 30, 2020 3:19:36 PM

I think I should be able to derive `*` automatically, since this newtype is ultimately a `Double`.

Can I? If not, how do I define my own method that takes a parameter, e.g. `*`?


Solution

  • You can define `*` manually:

    @newtype case class Red(toDouble: Double) {
      /** Multiplies the wrapped `Double` by `n` and re-wraps the product as a `Red`. */
      def *(n: Double): Red = Red.apply(n * toDouble)
    }
    
       // scalacOptions += "-Ymacro-debug-lite"
    //Warning:scalac: {
    //  type Red = Red.Type;
    //  object Red extends scala.AnyRef {
    //    def <init>() = {
    //      super.<init>();
    //      ()
    //    };
    //    def apply(toDouble: Double): Red = toDouble.asInstanceOf[Red];
    //    final implicit class Ops$newtype extends AnyVal {
    //      <paramaccessor> val $this$: Type = _;
    //      def <init>($this$: Type) = {
    //        super.<init>();
    //        ()
    //      };
    //      def toDouble: Double = $this$.asInstanceOf[Double];
    //      def $times(n: Double): Red = Red(toDouble.$times(n))
    //    };
    //    implicit def opsThis(x: Ops$newtype): Type = x.$this$;
    //    @new _root_.scala.inline() implicit def unsafeWrap: Coercible[Repr, Type] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def unsafeUnwrap: Coercible[Type, Repr] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def unsafeWrapM[M[_]]: Coercible[M[Repr], M[Type]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def unsafeUnwrapM[M[_]]: Coercible[M[Type], M[Repr]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotWrapArrayAmbiguous1: Coercible[_root_.scala.Array[Repr], _root_.scala.Array[Type]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotWrapArrayAmbiguous2: Coercible[_root_.scala.Array[Repr], _root_.scala.Array[Type]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotUnwrapArrayAmbiguous1: Coercible[_root_.scala.Array[Type], _root_.scala.Array[Repr]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotUnwrapArrayAmbiguous2: Coercible[_root_.scala.Array[Type], _root_.scala.Array[Repr]] = Coercible.instance;
    //    def deriving[TC[_]](implicit ev: TC[Repr]): TC[Type] = ev.asInstanceOf[TC[Type]];
    //    type Repr = Double;
    //    type Base = _root_.scala.Any {
    //      type __Red__newtype
    //    };
    //    abstract trait Tag extends _root_.scala.Any;
    //    type Type <: Base with Tag
    //  };
    //  ()
    //}
    

    If you want to derive all methods automatically (forwarding to the corresponding methods of `Double`), this can normally be done with `scala.Dynamic` plus a macro:

    import scala.language.dynamics
    import scala.language.experimental.macros
    import scala.reflect.macros.whitebox
    
    /** A red channel backed by a `Double`.
      *
      * Any call to a method not defined here (e.g. `red * n`) is routed by the compiler
      * through `applyDynamic`, which `Macro.impl` rewrites into the same call on `toDouble`.
      * The forwarded call therefore yields whatever `Double`'s method returns (for `*`, a plain
      * `Double`, not a `Red`).
      */
    case class Red(toDouble: Double) extends Dynamic {
      // The declared `Any` result is narrowed by the whitebox macro to the forwarded call's type.
      def applyDynamic(method: String)(args: Any*): Any = macro Macro.impl
    }
    
    /** Macro implementation backing `Red.applyDynamic`. */
    object Macro {

      /** Rewrites `prefix.applyDynamic("op")(args...)` into `prefix.toDouble.op(args...)`.
        *
        * @param c      whitebox context, so the result type can be narrower than `Any`
        * @param method tree of the dynamic method name; must be a string literal
        * @param args   argument trees, forwarded unchanged
        * @return the forwarding call tree
        */
      def impl(c: whitebox.Context)(method: c.Tree)(args: c.Tree*): c.Tree = {
        import c.universe._
        method match {
          // Compiler-generated dynamic calls always pass the name as a literal; an explicit
          // `applyDynamic(someVar)` call does not, and must fail with a clear error instead of a MatchError.
          case Literal(Constant(methodName: String)) =>
            // encodedName turns symbolic names such as "*" into their JVM form ($times).
            q"${c.prefix}.toDouble.${TermName(methodName).encodedName.toTermName}(..$args)"
          case other =>
            c.abort(other.pos, s"applyDynamic requires a string-literal method name, got: $other")
        }
      }
    }
    
    // Compile-time demonstration only: the `???` placeholders throw NotImplementedError
    // if this object is ever initialised at runtime.
    object Colors {
      val c: Color = ???
      val n: Double = ???
      c.red * n // expanded via applyDynamic into c.red.toDouble * n
    }
    
    //Warning:scalac: performing macro expansion Colors.this.c.red.applyDynamic("*")(Colors.this.n) ...
    //Warning:scalac: Colors.this.c.red.toDouble.$times(Colors.this.n)
    

    but unfortunately this does not work with `@newtype`, because a newtype may not extend any supertype:

    // Does NOT compile: @newtype rejects any supertype, including Dynamic.
    @newtype case class Red(toDouble: Double) extends Dynamic {
      def applyDynamic(method: String)(args: Any*): Any = macro Macro.impl
    }
    //Error: newtypes do not support inheritance; illegal supertypes: Dynamic
    

    Instead, you can define one more macro annotation, `@exportMethods`,

    import scala.annotation.{StaticAnnotation, compileTimeOnly}
    import scala.language.experimental.macros
    import scala.reflect.macros.blackbox
    
    /** Macro annotation for a class with exactly one constructor parameter.
      *
      * It adds to the class body a forwarder for every method declared on the parameter's type.
      * Forwarders whose result has the parameter's type are re-wrapped via the companion's `apply`.
      * Macro annotations must be enabled at compile time (on Scala 2.13, via the
      * `-Ymacro-annotations` scalac flag); otherwise `@compileTimeOnly` reports an error.
      */
    @compileTimeOnly("enable macro paradise")
    class exportMethods extends StaticAnnotation {
      def macroTransform(annottees: Any*): Any = macro ExportMethodsMacro.impl
    }
    
    /** Implementation of the `@exportMethods` macro annotation. */
    object ExportMethodsMacro {

      /** Appends forwarders for the wrapped field's declared methods to the annotated class.
        *
        * @param c         macro context
        * @param annottees the annotated class, optionally followed by its companion object
        * @return the class with the forwarders added, followed by the untouched companion (if any)
        */
      def impl(c: blackbox.Context)(annottees: c.Tree*): c.Tree = {
        import c.universe._
        annottees match {
          case q"$mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self => ..$stats }" :: tail =>
            val exportedMethods = paramss match {
              case (q"$_ val $paramName: $paramType = $_" :: Nil) :: Nil =>
                val paramTyp = c.typecheck(tq"$paramType", mode = c.TYPEmode).tpe
                // Skip the field accessor itself, getClass, and the constructor.
                val excluded = Set(paramName, TermName("getClass"), TermName("<init>"))
                paramTyp.decls
                  .filter(_.isMethod) // asMethod throws on non-method declarations
                  .map(_.asMethod)
                  .filterNot(method => excluded.contains(method.name))
                  .map { method =>
                    val defParamss = method.paramLists.map(_.map(p => q"val ${p.name.toTermName}: ${p.typeSignature}"))
                    val callArgss  = method.paramLists.map(_.map(p => q"${p.name.toTermName}"))
                    if (method.returnType =:= paramTyp)
                      // Keep results in the wrapper type, e.g. Red * Double => Red.
                      q"def ${method.name}(...$defParamss): $tpname = ${tpname.toTermName}.apply($paramName.${method.name}(...$callArgss))"
                    else
                      q"def ${method.name}(...$defParamss): ${method.returnType} = $paramName.${method.name}(...$callArgss)"
                  }
              case _ => c.abort(c.enclosingPosition, "class must have single parameter")
            }
            q"""
              $mods class $tpname[..$tparams] $ctorMods(...$paramss) extends { ..$earlydefns } with ..$parents { $self =>
                ..$stats
                ..$exportedMethods
              }
              ..$tail
            """
          case _ =>
            // Previously an unhandled MatchError inside the compiler.
            c.abort(c.enclosingPosition, "@exportMethods can only annotate a class")
        }
      }
    }
    

    and use it together with `@newtype`:

    import io.estatico.newtype.macros.newtype
    
    // @exportMethods must be listed before @newtype: it expands first, and the forwarders it adds
    // end up in the newtype's generated Ops class.
    @exportMethods @newtype case class Red(toDouble: Double)
    
    //Warning:scalac: {
    //  type Red = Red.Type;
    //  object Red extends scala.AnyRef {
    //    def <init>() = {
    //      super.<init>();
    //      ()
    //    };
    //    def apply(toDouble: Double): Red = toDouble.asInstanceOf[Red];
    //    final implicit class Ops$newtype extends AnyVal {
    //      <paramaccessor> val $this$: Type = _;
    //      def <init>($this$: Type) = {
    //        super.<init>();
    //        ()
    //      };
    //      def toDouble: Double = $this$.asInstanceOf[Double];
    //      def toByte: Byte = toDouble.toByte;
    //      def toShort: Short = toDouble.toShort;
    //      def toChar: Char = toDouble.toChar;
    //      def toInt: Int = toDouble.toInt;
    //      def toLong: Long = toDouble.toLong;
    //      def toFloat: Float = toDouble.toFloat;
    //      def unary_$plus: Red = Red.apply(toDouble.unary_$plus);
    //      def unary_$minus: Red = Red.apply(toDouble.unary_$minus);
    //      def $plus(x: String): String = toDouble.$plus(x);
    //      def $eq$eq(x: Byte): Boolean = toDouble.$eq$eq(x);
    //      def $eq$eq(x: Short): Boolean = toDouble.$eq$eq(x);
    //      def $eq$eq(x: Char): Boolean = toDouble.$eq$eq(x);
    //      def $eq$eq(x: Int): Boolean = toDouble.$eq$eq(x);
    //      def $eq$eq(x: Long): Boolean = toDouble.$eq$eq(x);
    //      def $eq$eq(x: Float): Boolean = toDouble.$eq$eq(x);
    //      def $eq$eq(x: Double): Boolean = toDouble.$eq$eq(x);
    //      def $bang$eq(x: Byte): Boolean = toDouble.$bang$eq(x);
    //      def $bang$eq(x: Short): Boolean = toDouble.$bang$eq(x);
    //      def $bang$eq(x: Char): Boolean = toDouble.$bang$eq(x);
    //      def $bang$eq(x: Int): Boolean = toDouble.$bang$eq(x);
    //      def $bang$eq(x: Long): Boolean = toDouble.$bang$eq(x);
    //      def $bang$eq(x: Float): Boolean = toDouble.$bang$eq(x);
    //      def $bang$eq(x: Double): Boolean = toDouble.$bang$eq(x);
    //      def $less(x: Byte): Boolean = toDouble.$less(x);
    //      def $less(x: Short): Boolean = toDouble.$less(x);
    //      def $less(x: Char): Boolean = toDouble.$less(x);
    //      def $less(x: Int): Boolean = toDouble.$less(x);
    //      def $less(x: Long): Boolean = toDouble.$less(x);
    //      def $less(x: Float): Boolean = toDouble.$less(x);
    //      def $less(x: Double): Boolean = toDouble.$less(x);
    //      def $less$eq(x: Byte): Boolean = toDouble.$less$eq(x);
    //      def $less$eq(x: Short): Boolean = toDouble.$less$eq(x);
    //      def $less$eq(x: Char): Boolean = toDouble.$less$eq(x);
    //      def $less$eq(x: Int): Boolean = toDouble.$less$eq(x);
    //      def $less$eq(x: Long): Boolean = toDouble.$less$eq(x);
    //      def $less$eq(x: Float): Boolean = toDouble.$less$eq(x);
    //      def $less$eq(x: Double): Boolean = toDouble.$less$eq(x);
    //      def $greater(x: Byte): Boolean = toDouble.$greater(x);
    //      def $greater(x: Short): Boolean = toDouble.$greater(x);
    //      def $greater(x: Char): Boolean = toDouble.$greater(x);
    //      def $greater(x: Int): Boolean = toDouble.$greater(x);
    //      def $greater(x: Long): Boolean = toDouble.$greater(x);
    //      def $greater(x: Float): Boolean = toDouble.$greater(x);
    //      def $greater(x: Double): Boolean = toDouble.$greater(x);
    //      def $greater$eq(x: Byte): Boolean = toDouble.$greater$eq(x);
    //      def $greater$eq(x: Short): Boolean = toDouble.$greater$eq(x);
    //      def $greater$eq(x: Char): Boolean = toDouble.$greater$eq(x);
    //      def $greater$eq(x: Int): Boolean = toDouble.$greater$eq(x);
    //      def $greater$eq(x: Long): Boolean = toDouble.$greater$eq(x);
    //      def $greater$eq(x: Float): Boolean = toDouble.$greater$eq(x);
    //      def $greater$eq(x: Double): Boolean = toDouble.$greater$eq(x);
    //      def $plus(x: Byte): Red = Red.apply(toDouble.$plus(x));
    //      def $plus(x: Short): Red = Red.apply(toDouble.$plus(x));
    //      def $plus(x: Char): Red = Red.apply(toDouble.$plus(x));
    //      def $plus(x: Int): Red = Red.apply(toDouble.$plus(x));
    //      def $plus(x: Long): Red = Red.apply(toDouble.$plus(x));
    //      def $plus(x: Float): Red = Red.apply(toDouble.$plus(x));
    //      def $plus(x: Double): Red = Red.apply(toDouble.$plus(x));
    //      def $minus(x: Byte): Red = Red.apply(toDouble.$minus(x));
    //      def $minus(x: Short): Red = Red.apply(toDouble.$minus(x));
    //      def $minus(x: Char): Red = Red.apply(toDouble.$minus(x));
    //      def $minus(x: Int): Red = Red.apply(toDouble.$minus(x));
    //      def $minus(x: Long): Red = Red.apply(toDouble.$minus(x));
    //      def $minus(x: Float): Red = Red.apply(toDouble.$minus(x));
    //      def $minus(x: Double): Red = Red.apply(toDouble.$minus(x));
    //      def $times(x: Byte): Red = Red.apply(toDouble.$times(x));
    //      def $times(x: Short): Red = Red.apply(toDouble.$times(x));
    //      def $times(x: Char): Red = Red.apply(toDouble.$times(x));
    //      def $times(x: Int): Red = Red.apply(toDouble.$times(x));
    //      def $times(x: Long): Red = Red.apply(toDouble.$times(x));
    //      def $times(x: Float): Red = Red.apply(toDouble.$times(x));
    //      def $times(x: Double): Red = Red.apply(toDouble.$times(x));
    //      def $div(x: Byte): Red = Red.apply(toDouble.$div(x));
    //      def $div(x: Short): Red = Red.apply(toDouble.$div(x));
    //      def $div(x: Char): Red = Red.apply(toDouble.$div(x));
    //      def $div(x: Int): Red = Red.apply(toDouble.$div(x));
    //      def $div(x: Long): Red = Red.apply(toDouble.$div(x));
    //      def $div(x: Float): Red = Red.apply(toDouble.$div(x));
    //      def $div(x: Double): Red = Red.apply(toDouble.$div(x));
    //      def $percent(x: Byte): Red = Red.apply(toDouble.$percent(x));
    //      def $percent(x: Short): Red = Red.apply(toDouble.$percent(x));
    //      def $percent(x: Char): Red = Red.apply(toDouble.$percent(x));
    //      def $percent(x: Int): Red = Red.apply(toDouble.$percent(x));
    //      def $percent(x: Long): Red = Red.apply(toDouble.$percent(x));
    //      def $percent(x: Float): Red = Red.apply(toDouble.$percent(x));
    //      def $percent(x: Double): Red = Red.apply(toDouble.$percent(x))
    //    };
    //    implicit def opsThis(x: Ops$newtype): Type = x.$this$;
    //    @new _root_.scala.inline() implicit def unsafeWrap: Coercible[Repr, Type] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def unsafeUnwrap: Coercible[Type, Repr] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def unsafeWrapM[M[_]]: Coercible[M[Repr], M[Type]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def unsafeUnwrapM[M[_]]: Coercible[M[Type], M[Repr]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotWrapArrayAmbiguous1: Coercible[_root_.scala.Array[Repr], _root_.scala.Array[Type]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotWrapArrayAmbiguous2: Coercible[_root_.scala.Array[Repr], _root_.scala.Array[Type]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotUnwrapArrayAmbiguous1: Coercible[_root_.scala.Array[Type], _root_.scala.Array[Repr]] = Coercible.instance;
    //    @new _root_.scala.inline() implicit def cannotUnwrapArrayAmbiguous2: Coercible[_root_.scala.Array[Type], _root_.scala.Array[Repr]] = Coercible.instance;
    //    def deriving[TC[_]](implicit ev: TC[Repr]): TC[Type] = ev.asInstanceOf[TC[Type]];
    //    type Repr = Double;
    //    type Base = _root_.scala.Any {
    //      type __Red__newtype
    //    };
    //    abstract trait Tag extends _root_.scala.Any;
    //    type Type <: Base with Tag
    //  };
    //  ()
    //}
    

    Testing:

    // `c.red * n` now resolves to the exported `*(x: Double): Red` forwarder.
    multiply(Color(Red(1.0), 2.0, 3.0), 4.0) //Color(4.0,8.0,12.0)
    

    The order of annotations is significant: `@exportMethods` is expanded first and `@newtype` second, so the exported methods end up in the newtype's generated `Ops$newtype` class.