Search code examples
Tags: scala, traits, scala-3

How to let the implementor of a trait define the context parameter type


I would like to define a trait in a library where the type of the contextual parameter is defined by the user of the library. For example, it would look something like:

////////////////////////
// library
////////////////////////
// Library interface: each implementor should be able to choose the type of
// context that `make_sound` receives. The question uses `???` as a
// placeholder for that not-yet-known type; it does not compile as written.
trait Animal:
    // goal: the type of the context we pass to `make_sound` is defined by the implementor
    def make_sound()(using ctx: ???): Unit


// Library code that only knows the abstract `Animal`. It declares no using
// clause of its own, so any context `make_sound` needs would have to be
// resolved at the call below, where the concrete context type is not known.
def do_something(animal: Animal) = {
    animal.make_sound()
}

////////////////////////
// user code
////////////////////////
// User-defined context type that `Cat.make_sound` wants to receive.
trait CatMakeSoundContext:
    def get_claw_length(): Int

// User implementation of `Animal` that needs a `CatMakeSoundContext` to choose
// between "scratch!" (claw length above 2) and "meow".
class Cat extends Animal:
    // FAILS because this doesn't have the same signature as in `Animal`
    // (an overriding method must keep the parent's parameter types; it cannot
    // narrow the context parameter to `CatMakeSoundContext`).
    override def make_sound()(using ctx: CatMakeSoundContext) = {
        if ctx.get_claw_length() > 2 {
            println("scratch!")
        } else {
            println("meow")
        }
    }

// Concrete context that always reports a claw length of 2.
class MyCatMakeSoundContext extends CatMakeSoundContext {
    override def get_claw_length(): Int = 2
}

// Desired caller-side usage. The `...` on the first line is elided in the
// question, so this does not compile as written.
def call_example() = {
    val cat: Cat = ...;
    val claw_context = MyCatMakeSoundContext()    
      
    // somehow set up the `claw_context` with a `given` clause
    // to be used by `animal.make_sound()`
 
    // call `do_something()`, which will end up using `claw_context`
    // as the value for the implicit `ctx` in `make_sound()`
    do_something(cat)
}

I'm most likely using the wrong language features in this example. Is it possible to achieve something similar to what I described here, one way or another? I suspect generics could help, but I wasn't able to figure it out.


Solution

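  • You can make the context an abstract type member that each implementation defines. The caller then needs the concrete animal type (here `Cat`) statically at the call site: if only `Animal` is known, `animal.Context` stays abstract and no given can be found for it.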

    // library
    /** Something that can make a sound.
      *
      * Each implementation picks the type of context it needs by defining the
      * abstract type member `Context`; `make_sound` receives a value of that
      * type through a using clause.
      */
    trait Animal:
      /** Context type required by `make_sound`; chosen by the implementation. */
      type Context

      /** Makes this animal's sound using the implementation-chosen context. */
      def make_sound()(using ctx: Context): Unit
    end Animal
    
    /** Asks `animal` to make its sound, forwarding the caller's given of the
      * animal-specific context type `animal.Context`.
      */
    def do_something(animal: Animal)(using context: animal.Context): Unit =
      animal.make_sound()(using context)
    
    // user code
    /** Context that `Cat.make_sound` consults to decide which sound to make. */
    trait CatMakeSoundContext:
      /** Claw length that `Cat.make_sound` compares against 2. */
      def get_claw_length(): Int
    end CatMakeSoundContext
    
    /** `Animal` whose context is a `CatMakeSoundContext`: it prints "scratch!"
      * when the reported claw length exceeds 2 and "meow" otherwise.
      */
    class Cat extends Animal:
      override type Context = CatMakeSoundContext

      override def make_sound()(using ctx: Context): Unit =
        val sound = if ctx.get_claw_length() > 2 then "scratch!" else "meow"
        println(sound)
    end Cat
    
    /** Sample context that always reports a claw length of 2. */
    class MyCatMakeSoundContext extends CatMakeSoundContext:
      override def get_claw_length(): Int = 2
    end MyCatMakeSoundContext
    
    /** Demo: with a `CatMakeSoundContext` given in scope, `do_something` makes
      * the cat meow (a claw length of 2 is not greater than 2).
      */
    def call_example(): Unit =
      val cat = Cat()
      given clawContext: CatMakeSoundContext = MyCatMakeSoundContext()
      do_something(cat) // meow
    

    Alternatively, you can make the context a type parameter of `Animal`:

    // library
    /** Something that can make a sound, parameterized by the type of context
      * its `make_sound` method needs.
      *
      * @tparam Context the context type chosen by each implementation
      */
    trait Animal[Context]:
      /** Makes this animal's sound using the given context. */
      def make_sound()(using ctx: Context): Unit
    end Animal
    
    /** Asks `animal` to make its sound, forwarding the caller's given of the
      * context type `C` (inferred from the type of `animal`).
      */
    def do_something[C](animal: Animal[C])(using context: C): Unit =
      animal.make_sound()(using context)
    
    // user code
    /** Context that `Cat.make_sound` consults to decide which sound to make. */
    trait CatMakeSoundContext:
      /** Claw length that `Cat.make_sound` compares against 2. */
      def get_claw_length(): Int
    end CatMakeSoundContext
    
    /** `Animal` needing a `CatMakeSoundContext`: prints "scratch!" when the
      * claw length is above 2 and "meow" otherwise.
      */
    class Cat extends Animal[CatMakeSoundContext]:
      override def make_sound()(using ctx: CatMakeSoundContext): Unit =
        ctx.get_claw_length() match
          case length if length > 2 => println("scratch!")
          case _                    => println("meow")
    end Cat
    
    /** Sample context that always reports a claw length of 2. */
    class MyCatMakeSoundContext extends CatMakeSoundContext:
      override def get_claw_length(): Int = 2
    end MyCatMakeSoundContext
    
    /** Demo: `C` is inferred as `CatMakeSoundContext` from `Cat`, so the given
      * below is used and the cat meows.
      */
    def call_example(): Unit =
      val cat = Cat()
      given clawContext: CatMakeSoundContext = MyCatMakeSoundContext()
      do_something(cat) // meow
    

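    Additionally, you can make `Animal` a type class. Note that in this version `make_sound` does not receive the animal value, so the argument `t` of `do_something` only serves to infer `T`: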

    // library
    /** Type class for types `T` that can make a sound. Each instance chooses
      * the context type its `make_sound` needs via the type member `Context`.
      */
    trait Animal[T]:
      /** Context type required by this instance's `make_sound`. */
      type Context

      /** Makes the sound for `T` using the instance-chosen context. */
      def make_sound()(using ctx: Context): Unit
    end Animal
    
    /** Makes the sound for a value of type `T`, using the `Animal[T]` instance
      * in scope and a given of that instance's `Context` type.
      *
      * `t` is never read; it only fixes `T` for instance resolution.
      */
    def do_something[T](t: T)(using animal: Animal[T], ctx: animal.Context): Unit =
      animal.make_sound()(using ctx)
    
    // user code
    /** Context that the `Animal[Cat]` instance consults to pick a sound. */
    trait CatMakeSoundContext:
      /** Claw length that the `Animal[Cat]` instance compares against 2. */
      def get_claw_length(): Int
    end CatMakeSoundContext
    
    /** Plain cat type; its `Animal` behaviour comes from the given below. */
    class Cat

    /** `Animal` instance for `Cat`: prints "scratch!" when the claw length
      * exceeds 2 and "meow" otherwise.
      */
    given Animal[Cat] with
      override type Context = CatMakeSoundContext
      override def make_sound()(using ctx: Context): Unit =
        println(if ctx.get_claw_length() > 2 then "scratch!" else "meow")
    
    /** Sample context that always reports a claw length of 2. */
    class MyCatMakeSoundContext extends CatMakeSoundContext:
      override def get_claw_length(): Int = 2
    end MyCatMakeSoundContext
    
    /** Demo: resolves the `Animal[Cat]` instance, then a given of its `Context`
      * (`CatMakeSoundContext`), so the cat meows.
      */
    def call_example(): Unit =
      val cat = Cat()
      given clawContext: CatMakeSoundContext = MyCatMakeSoundContext()
      do_something(cat) // meow