
Tarjan's algorithm: Python to Scala


I'm trying to translate the recursive Python code for Tarjan's algorithm to Scala, especially this part:

def tarjan_recursive(g):
        S = []
        S_set = set()
        index = {}
        lowlink = {}
        ret = []
        def visit(v):
                index[v] = len(index)
                lowlink[v] = index[v]
                S.append(v)
                S_set.add(v)

                for w in g.get(v,()):
                        print(w)
                        if w not in index:
                                visit(w)
                                lowlink[v] = min(lowlink[w], lowlink[v])
                        elif w in S_set:
                                lowlink[v] = min(lowlink[v], index[w])
                if lowlink[v] == index[v]:
                        scc = []
                        w = None
                        while v != w:
                                w = S.pop()
                                scc.append(w)
                                S_set.remove(w)
                        ret.append(scc)

        for v in g:
                print(index)
                if not v in index:
                        visit(v)
        return ret
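
For anyone doing the same translation, here is a small sketch, assuming Scala's mutable collections (scala.collection.mutable), of how the Python operations used above map onto Scala. The sample map and values are illustrative only, and this is one reasonable mapping rather than the only one.

```scala
import scala.collection.mutable

// Python: g.get(v, ())  ->  Scala: g.getOrElse(v, Nil)
// A vertex with no entry in the map simply has no successors.
val g = Map(1 -> List(2), 2 -> List(1))
val successorsOf3 = g.getOrElse(3, Nil)

// Python: S.append(v) and S.pop()
// remove(i) deletes AND returns the element; .last only reads it.
val stack = mutable.ArrayBuffer.empty[Int]
stack += 1
stack += 2
val top = stack.remove(stack.size - 1)

// Python: index[v] = len(index)
// The right-hand side is evaluated before the entry is inserted.
val index = mutable.Map.empty[Int, Int]
index(7) = index.size
index(8) = index.size

// Python: w = None, a "no vertex yet" marker; an Option is one Scala equivalent
var w: Option[Int] = None
w = Some(top)
```

The pop is where a direct translation most easily goes wrong: reading the top of the buffer with last leaves it on the stack.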

I know there are Scala implementations of Tarjan's algorithm here or here, but they don't return correct results, and translating the Python version helps me understand it.

Here's what I have:

def tj_recursive(g: Map[Int,List[Int]])= {
    var s : mutable.ListBuffer[Int] = new mutable.ListBuffer()
    var s_set : mutable.Set[Int] = mutable.Set()
    var index : mutable.Map[Int,Int] =  mutable.Map()
    var lowlink : mutable.Map[Int,Int]=  mutable.Map()
    var ret : mutable.Map[Int,mutable.ListBuffer[Int]]= mutable.Map()

    def visit(v: Int):Int = {
        index(v) = index.size
        lowlink(v) = index(v)
        var zz :List[Int]= gg.get(v).toList(0)
        for( w <- zz) {
            if( !(index.contains(w)) ){
                visit(w)
                lowlink(v) = List(lowlink(w),lowlink(v)).min
            }else if(s_set.contains(w)){
                lowlink(v)=List(lowlink(v),index(w)).min
            }
        }

        if(lowlink(v)==index(v)){
            var scc:mutable.ListBuffer[Int] = new mutable.ListBuffer()
            var w:Int=null.asInstanceOf[Int]
            while(v!=w){
                w= s.last
                scc+=w
                s_set-=w
            }
            ret+=scc
        }
    }

    for( v <- g) {if( !(index.contains(v)) ){visit(v)}}
    ret
}

I know this isn't idiomatic Scala at all (and it isn't clean...), but I plan to gradually move it to a more functional style once the first version works.

For now, I get this error:

type mismatch;
 found   : Unit
 required: Int

at this line:

if(lowlink(v)==index(v)){ 

I think it's coming from this line, but I'm not sure:

if( !(index.contains(w)) ){

But it's really hard to debug, since the code doesn't compile and I can't just add println calls to track down my mistakes...

Thanks!
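
To see where that compiler message comes from, here is a minimal sketch (the method name popIfRoot is hypothetical). In Scala an if with no else is treated as if (cond) body else (), so when the condition is false the whole expression evaluates to the Unit value (). If such an if is the last expression of a method declared to return Int, the compiler reports "found: Unit, required: Int" at the if, which matches the error reported for visit.

```scala
// Compiles: the method is declared to return Unit, so a trailing
// `if` without `else` is fine (it is used only for its side effect).
def popIfRoot(isRoot: Boolean): Unit = {
  if (isRoot) println("pop an SCC here")
}

// On the false path, an `if` without `else` evaluates to `()`.
val whenFalse = if (false) 42

// The same shape declared with `: Int` would NOT compile:
//   def visit(v: Int): Int = { ...; if (lowlink(v) == index(v)) { ... } }
// because the implicit `else ()` branch yields Unit, not Int.
```

Declaring the method as Unit (or giving the if an else branch of the right type) removes the mismatch.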


Solution

  • Here's a fairly literal translation of the Python:

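    import scala.collection.mutable

    def tj_recursive(g: Map[Int, List[Int]])= {
      val s = mutable.Buffer.empty[Int]         // Tarjan's stack (top is the last element)
      val s_set = mutable.Set.empty[Int]        // constant-time "is it on the stack?" check
      val index = mutable.Map.empty[Int, Int]   // discovery order of each vertex
      val lowlink = mutable.Map.empty[Int, Int]
      val ret = mutable.Buffer.empty[mutable.Buffer[Int]]
    
      def visit(v: Int): Unit = {
        index(v) = index.size
        lowlink(v) = index(v)
        s += v
        s_set += v
    
        // getOrElse mirrors Python's g.get(v, ()): a vertex that is not a key has no successors
        for (w <- g.getOrElse(v, Nil)) {
          if (!index.contains(w)) {
            visit(w)
            lowlink(v) = math.min(lowlink(w), lowlink(v))
          } else if (s_set(w)) {
            lowlink(v) = math.min(lowlink(v), index(w))
          }
        }
    
        // v is the root of an SCC: pop vertices until v itself has been popped
        if (lowlink(v) == index(v)) {
          val scc = mutable.Buffer.empty[Int]
          var w = -1  // stands in for Python's None; assumes -1 is never a vertex
    
          while(v != w) {
            w = s.remove(s.size - 1)  // pop: remove and return the last element
            scc += w
            s_set -= w
          }
    
          ret += scc
        }
      }
    
      // iterate over the keys; iterating over g itself yields (key, value) pairs
      for (v <- g.keys) if (!index.contains(v)) visit(v)
      ret
    }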
    

    It produces the same output as the Python version on, for example:

    tj_recursive(Map(
      1 -> List(2),    2 -> List(1, 5), 3 -> List(4),
      4 -> List(3, 5), 5 -> List(6),    6 -> List(7),
      7 -> List(8),    8 -> List(6, 9), 9 -> Nil
    ))
    

    The biggest problems with your implementation were the return type of visit (which should have been Unit, not Int) and the fact that the final for loop iterated over the graph's items (key-value pairs) instead of its keys. You also never pushed v onto s, and s.last only reads the last element without removing it, whereas Python's S.pop() removes and returns it. Beyond those fixes, I've made a number of other edits for style and clarity (while still keeping the basic shape).
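
The question mentions moving toward a more functional style. One possible direction, sketched under the assumption that the graph is small enough for plain recursion, is to thread an immutable state through visit instead of mutating shared collections. TarjanState and tarjanFunctional are hypothetical names; the visiting logic is the same as in the translation above, the stack is a List whose head is the top, and g.getOrElse(v, Nil) plays the role of Python's g.get(v, ()).

```scala
// Immutable bookkeeping threaded through the recursion (hypothetical names).
final case class TarjanState(
  index: Map[Int, Int],
  lowlink: Map[Int, Int],
  stack: List[Int],        // head of the list is the top of the stack
  onStack: Set[Int],
  sccs: Vector[List[Int]]  // components in the order they are completed
)

def tarjanFunctional(g: Map[Int, List[Int]]): Vector[List[Int]] = {
  def visit(v: Int, st0: TarjanState): TarjanState = {
    val i = st0.index.size
    val pushed = st0.copy(
      index = st0.index + (v -> i),
      lowlink = st0.lowlink + (v -> i),
      stack = v :: st0.stack,
      onStack = st0.onStack + v
    )

    // Same two cases as the Python loop, but each step returns a new state.
    val explored = g.getOrElse(v, Nil).foldLeft(pushed) { (st, w) =>
      if (!st.index.contains(w)) {
        val after = visit(w, st)
        after.copy(lowlink = after.lowlink + (v -> math.min(after.lowlink(v), after.lowlink(w))))
      } else if (st.onStack(w)) {
        st.copy(lowlink = st.lowlink + (v -> math.min(st.lowlink(v), st.index(w))))
      } else st
    }

    if (explored.lowlink(v) == explored.index(v)) {
      // v is a root: everything above v on the stack, plus v, forms one SCC.
      val (above, fromV) = explored.stack.span(_ != v)
      val scc = above :+ v
      explored.copy(
        stack = fromV.tail,
        onStack = explored.onStack -- scc,
        sccs = explored.sccs :+ scc
      )
    } else explored
  }

  val empty = TarjanState(Map.empty, Map.empty, Nil, Set.empty, Vector.empty)
  // Iterate over g.keys: iterating over g itself yields (Int, List[Int]) pairs.
  g.keys.foldLeft(empty) { (st, v) =>
    if (st.index.contains(v)) st else visit(v, st)
  }.sccs
}
```

An immutable Map with more than four entries does not preserve insertion order (unlike a Python 3.7+ dict), so the components may come out in a different order than the Python output; comparing them as sets is the safer check. Whatever the key order, Tarjan's algorithm emits components in reverse topological order, so on the example graph the component containing only 9 always comes first. The recursion is not tail-recursive, so very deep graphs could overflow the stack, much as the Python version can hit its recursion limit.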