
Returning the same type the function was passed


I have the following implementation of breadth-first search:

trait State {
  def successors: Seq[State]
  def isSuccess: Boolean = false
  def admissableHeuristic: Double
}

def breadthFirstSearch(initial: State): Option[List[State]] = {
  // Each queue entry is a path, stored most-recent state first
  val open = new scala.collection.mutable.Queue[List[State]]
  val closed = new scala.collection.mutable.HashSet[State]
  open.enqueue(initial :: Nil)
  while (!open.isEmpty) {
    val path: List[State] = open.dequeue()
    if (path.head.isSuccess) return Some(path.reverse)
    closed += path.head
    // Extend the path with every successor that has not been expanded yet
    for (x <- path.head.successors)
      if (!closed.contains(x))
        open.enqueue(x :: path)
  }

  return None
}

If I define a subtype of State for my particular problem:

class CannibalsState extends State {
 //...
}

What's the best way to make breadthFirstSearch return the same subtype as it was passed?
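To make the problem concrete, here is a small runnable sketch (an illustration, not part of the original question). It wraps essentially the same search in an object and runs it on a hypothetical toy subtype, `CounterState`, which counts from 0 up to a target in steps of 1 or 2. The result is typed `Option[List[State]]`, so getting the subtype back takes a runtime type test:

```scala
import scala.collection.mutable

object OriginalSignatureDemo {
  trait State {
    def successors: Seq[State]
    def isSuccess: Boolean = false
    def admissableHeuristic: Double
  }

  // Essentially the question's breadth-first search, lightly tidied.
  def breadthFirstSearch(initial: State): Option[List[State]] = {
    val open = mutable.Queue[List[State]](initial :: Nil)
    val closed = mutable.HashSet[State]()
    while (open.nonEmpty) {
      val path = open.dequeue()
      if (path.head.isSuccess) return Some(path.reverse)
      closed += path.head
      for (x <- path.head.successors if !closed.contains(x))
        open.enqueue(x :: path)
    }
    None
  }

  // Hypothetical toy subtype: move from n towards target in steps of 1 or 2.
  case class CounterState(n: Int, target: Int) extends State {
    def successors: Seq[State] =
      Seq(copy(n = n + 1), copy(n = n + 2)).filter(_.n <= target)
    override def isSuccess: Boolean = n == target
    def admissableHeuristic: Double = (target - n) / 2.0
  }

  def demo(): Option[List[Int]] = {
    // Statically this is Option[List[State]], not Option[List[CounterState]]
    val result: Option[List[State]] = breadthFirstSearch(CounterState(0, 3))
    // Recovering the subtype needs a runtime type test
    result.map(_.collect { case c: CounterState => c.n })
  }
}
```

Here `OriginalSignatureDemo.demo()` yields `Some(List(0, 1, 3))`, but only because of the `collect` with a type pattern; the compiler itself never knows the path contains `CounterState` values.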

Suppose I change this so that there are three different state classes for my particular problem, all sharing a common supertype:

abstract class CannibalsState extends State {
 //...
}
class LeftSideOfRiver extends CannibalsState {
 //...
}
class InTransit extends CannibalsState {
 //...
}
class RightSideOfRiver extends CannibalsState {
 //...
}

How can I make the types work out so that breadthFirstSearch infers a return type of Option[List[CannibalsState]] when it's passed an instance of LeftSideOfRiver?

Can this be done with an abstract type member, or must it be done with generics?


Solution

  • One option is to use generics, as Randall described. To get the same effect with an abstract type member, you can do it like this (based on Mitch's code):

    trait ProblemType {
    
        type S <: State
    
        trait State {
            def successors: Seq[S]
            def isSuccess: Boolean = false
            def admissableHeuristic: Double
        }
    
        def breadthFirstSearch(initial: S): Option[List[S]] = {
            val open = new scala.collection.mutable.Queue[List[S]]
            val closed = new scala.collection.mutable.HashSet[S]
            open.enqueue(initial :: Nil)
            while (!open.isEmpty) {
                val path: List[S] = open.dequeue()
                if (path.head.isSuccess) return Some(path.reverse)
                closed += path.head
                for (x <- path.head.successors)
                    if (!closed.contains(x))
                        open.enqueue(x :: path)
            }
    
            return None
        }
    
    }
    
    object RiverCrossingProblem extends ProblemType {
    
        type S = CannibalsState
    
        abstract class CannibalsState extends State {
         //...
        }
        class LeftSideOfRiver extends CannibalsState {
         //...
        }
        class InTransit extends CannibalsState {
         //...
        }
        class RightSideOfRiver extends CannibalsState {
         //...
        }
    
    }
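For comparison, one common way to write the generics option is F-bounded polymorphism, where `State` is parameterized by its own concrete type. This is a sketch under stated assumptions, not necessarily Randall's code. The river states below are hypothetical and heavily simplified (case objects, no cannibal or missionary rules) so that the example runs. Note that the call passes the type argument explicitly: with a bound like `S <: State[S]`, inference may settle on the argument's own narrower type (here `LeftSideOfRiver.type`), which does not satisfy the bound. The abstract type member version avoids this, because `S` is fixed once per problem.

```scala
import scala.collection.mutable

object GenericSearch {
  // F-bounded parameter: a State[S] produces successors of the concrete type S
  trait State[S <: State[S]] {
    def successors: Seq[S]
    def isSuccess: Boolean = false
    def admissableHeuristic: Double
  }

  def breadthFirstSearch[S <: State[S]](initial: S): Option[List[S]] = {
    val open = mutable.Queue[List[S]](initial :: Nil)
    val closed = mutable.HashSet[S]()
    while (open.nonEmpty) {
      val path = open.dequeue()
      if (path.head.isSuccess) return Some(path.reverse)
      closed += path.head
      // path.head.successors is Seq[S], so x :: path stays a List[S]
      for (x <- path.head.successors if !closed.contains(x))
        open.enqueue(x :: path)
    }
    None
  }

  // Hypothetical, simplified river states sharing the supertype CannibalsState
  sealed abstract class CannibalsState extends State[CannibalsState] {
    def admissableHeuristic: Double = 0.0
  }
  case object LeftSideOfRiver extends CannibalsState {
    def successors: Seq[CannibalsState] = Seq(InTransit)
  }
  case object InTransit extends CannibalsState {
    def successors: Seq[CannibalsState] = Seq(LeftSideOfRiver, RightSideOfRiver)
  }
  case object RightSideOfRiver extends CannibalsState {
    def successors: Seq[CannibalsState] = Seq.empty
    override def isSuccess: Boolean = true
  }

  // The explicit [CannibalsState] pins S to the common supertype
  def demo(): Option[List[CannibalsState]] =
    breadthFirstSearch[CannibalsState](LeftSideOfRiver)
}
```

With these toy states, `GenericSearch.demo()` returns the path `LeftSideOfRiver`, `InTransit`, `RightSideOfRiver`, typed as `Option[List[CannibalsState]]` with no casts.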