
Using CountDownLatch & Object.wait inside recursive block hangs


Problem: While trying to retrieve values from inside a recursive block in phases, execution hangs.

Description: CountDownLatch and Object.wait are used so that the main thread can read the value in phases while the recursive block is running. However, the program hangs with the following output:

2 < 16
3 < 16
4 < 16
5 < 16
Current total: 5
 Inside of wait
 Inside of wait

Program:

import java.util.concurrent.*;
public class RecursiveTotalFinder {
    private static CountDownLatch latch1;
    private static CountDownLatch latch2;
    private static CountDownLatch latch3;
    public static void main(String... args) {
       latch1 = new CountDownLatch(1);
       latch2 = new CountDownLatch(1);
       latch3 = new CountDownLatch(1);

       //Create object
       TotalFinder tf = new TotalFinder(latch1,latch2,latch3);

       //Start the thread
       tf.start();

       //Wait for results from TotalFinder
       try {
           latch1.await();
       } catch(InterruptedException ie) {
           ie.printStackTrace();
       }

       //Print the result after 5th iteration
       System.out.println("Current total: "+tf.getCurrentTotal());
       tf.releaseWaitLock();
       tf.resetWaitLock();

       //Wait for results again
       try {
           latch2.await();
       } catch(InterruptedException ie) {
           ie.printStackTrace();
       }

       //Print the result after 10th iteration
       System.out.println("Current total: "+tf.getCurrentTotal());
       tf.releaseWaitLock();
       tf.resetWaitLock();

       //Wait for results again
       try {
           latch3.await();
       } catch(InterruptedException ie) {
           ie.printStackTrace();
       }

       //Print the result after 15th iteration
       System.out.println("Current total: "+tf.getCurrentTotal());
       tf.releaseWaitLock();
       tf.resetWaitLock();
    }
}


class TotalFinder extends Thread{
    CountDownLatch tfLatch1;
    CountDownLatch tfLatch2;
    CountDownLatch tfLatch3;
    private static int count = 1;
    private static final class Lock { }
    private final Object lock = new Lock();
    private boolean gotSignalFromMaster = false;

    public TotalFinder(CountDownLatch latch1, CountDownLatch latch2, 
                       CountDownLatch latch3) {
        tfLatch1 = latch1;
        tfLatch2 = latch2;
        tfLatch3 = latch3;
    }

    public void run() {
        findTotal(16);
    }

    //Find total
    synchronized void findTotal(int cnt) {
        if(count%5==0) {
           if(count==5)
              tfLatch1.countDown();
           if(count==10)
              tfLatch2.countDown();
           if(count==15)
              tfLatch3.countDown();

           //Sleep for sometime
           try {
               Thread.sleep(3000);
           } catch(InterruptedException ie) {
               ie.printStackTrace();
           }
           //Wait till current total is printed

           synchronized(lock) {
              while(gotSignalFromMaster==false) {
                 try {
                    System.out.println(" Inside of wait");
                    lock.wait();
                 } catch(InterruptedException ie) {
                    ie.printStackTrace();
                 }
              }
              System.out.println("Came outside of wait");
           }

        }
        count +=1;
        if(count < cnt) {
           System.out.println(count +" < "+cnt);
           findTotal(cnt);
        }
    }

    //Return the count value
    public int getCurrentTotal() {
       return count;
    }

    //Release lock
    public void releaseWaitLock() {
        //Sleep for sometime
        try {
            Thread.sleep(5000);
        } catch(InterruptedException ie) {
            ie.printStackTrace();
        }

        synchronized(lock) {
           gotSignalFromMaster=true;
           lock.notifyAll();
        }
    }

    //Reset wait lock
    public void resetWaitLock() {
        gotSignalFromMaster = false;
    }
}

Analysis: From my initial analysis, it looks like the wait is entered again and again, even though notifyAll is invoked from the main program.

Help: Why does releasing the lock with notifyAll after a CountDownLatch not take effect? I need help understanding what exactly is happening in this program.


Solution

  • The main message about wait and notify that I took from Java Concurrency in Practice (JCIP) was that I'd probably use them wrongly, so it is better to avoid using them directly unless strictly necessary. As such, I think you should reconsider using these methods.

    In this case, I think you can do it more elegantly with a SynchronousQueue. Perhaps something like this might work:

    import java.util.concurrent.*;
    public class RecursiveTotalFinder {
        public static void main(String... args) throws InterruptedException {
           SynchronousQueue<Integer> syncQueue = new SynchronousQueue<>();
    
           //Create object
           TotalFinder tf = new TotalFinder(syncQueue, 5);
    
           //Start the thread
           tf.start();
    
           for (int i = 0; i < 3; ++i) {
             System.out.println("Current total: " + syncQueue.take());
           }
        }
    }
    
    class TotalFinder extends Thread{
      private final SynchronousQueue<Integer> syncQueue;
      private final int syncEvery;
      private int count;
    
      public TotalFinder(SynchronousQueue<Integer> syncQueue, 
                         int syncEvery) {
        this.syncQueue = syncQueue;
        this.syncEvery = syncEvery;
      }
    
      public void run() {
        try {
          findTotal(16);
        } catch (InterruptedException e) {
          Thread.currentThread().interrupt();
          throw new RuntimeException(e);
        }
      }
    
      //Find total
      void findTotal(int cnt) throws InterruptedException {
        if((count > 0) && (count%syncEvery==0)) {
          syncQueue.put(count);
        }
        count +=1;
        if(count < cnt) {
          System.out.println(count +" < "+cnt);
          findTotal(cnt);
        }
      }
    }
    

    As to why your original approach doesn't work: the main thread sets gotSignalFromMaster to true in releaseWaitLock and then immediately sets it back to false in resetWaitLock, before the waiting thread has had a chance to wake up, reacquire the lock, and re-check the flag. The waiter therefore sees false, prints " Inside of wait" a second time, and waits again with nobody left to notify it. If you add a short sleep to resetWaitLock, it proceeds beyond the point where it currently hangs; however, it then hangs at the end instead of terminating.
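
    If you do want to stay with wait and notifyAll, one way to close that race is to let the waiting thread reset the flag itself while it still holds the lock, so the signal cannot be cleared before it is seen. The following is only a sketch under assumptions: SignalGate, SignalGateDemo and runPhases are hypothetical names that are not in your program, and the counting is done in a loop rather than recursively to keep it short.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper: a binary signal whose flag is only read or written under the lock.
final class SignalGate {
    private final Object lock = new Object();
    private boolean signalled = false;

    // Called by the waiting thread; consumes the signal before returning.
    void awaitSignal() throws InterruptedException {
        synchronized (lock) {
            while (!signalled) {
                lock.wait();
            }
            signalled = false; // reset by the waiter, still holding the lock
        }
    }

    // Called by the signalling thread.
    void signal() {
        synchronized (lock) {
            signalled = true;
            lock.notifyAll();
        }
    }
}

final class SignalGateDemo {
    // Counts on a worker thread and returns the totals the caller observed at each pause.
    static List<Integer> runPhases(int limit, int every) throws InterruptedException {
        SignalGate reached = new SignalGate(); // worker -> caller: "a total is ready"
        SignalGate resume = new SignalGate();  // caller -> worker: "carry on"
        int[] checkpoint = new int[1];
        int phases = (limit - 1) / every;

        Thread worker = new Thread(() -> {
            try {
                for (int count = 1; count < limit; count++) {
                    if (count % every == 0) {
                        checkpoint[0] = count; // written before signalling
                        reached.signal();
                        resume.awaitSignal();  // pause until the caller has read it
                    }
                }
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        worker.start();

        List<Integer> seen = new ArrayList<>();
        for (int i = 0; i < phases; i++) {
            reached.awaitSignal();
            seen.add(checkpoint[0]);
            resume.signal();
        }
        worker.join();
        return seen;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(runPhases(16, 5)); // prints [5, 10, 15]
    }
}
```

    Because the flag is consumed by the thread that waited for it, there is no separate reset call for the main thread to issue too early, and no sleep is needed on either side.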

    Note that relying on Thread.sleep to wait for another thread to change some state is a poor approach, not least because it makes your program really slow. Using synchronization utilities leads to a faster program that is much easier to reason about.
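
    As a further illustration of that point, here is a rough sketch that keeps the recursive findTotal shape but replaces the latches, the flag and every sleep with a pair of java.util.concurrent.Semaphore objects. The class names SemaphoreTotalFinder and SemaphoreDemo and the nextTotal method are assumptions made for this example, not part of your code. Unlike the SynchronousQueue version, the worker here stays paused until the main thread has actually read the total.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Semaphore;

// Hypothetical variant of TotalFinder that hands over control with two semaphores.
final class SemaphoreTotalFinder extends Thread {
    private final Semaphore totalReady = new Semaphore(0);  // worker -> main
    private final Semaphore mayContinue = new Semaphore(0); // main -> worker
    private final int limit;
    private final int every;
    private int count = 1;

    SemaphoreTotalFinder(int limit, int every) {
        this.limit = limit;
        this.every = every;
    }

    @Override
    public void run() {
        try {
            findTotal();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    // Same recursive structure as the original, without synchronized or sleep.
    private void findTotal() throws InterruptedException {
        if (count % every == 0) {
            totalReady.release();  // tell main a total is available
            mayContinue.acquire(); // block until main has read it
        }
        count += 1;
        if (count < limit) {
            findTotal();
        }
    }

    // Blocks until the next checkpoint, reads the total, then lets the worker go on.
    int nextTotal() throws InterruptedException {
        totalReady.acquire();
        int total = count; // the worker is parked in mayContinue.acquire() here
        mayContinue.release();
        return total;
    }
}

final class SemaphoreDemo {
    static List<Integer> collect() throws InterruptedException {
        SemaphoreTotalFinder tf = new SemaphoreTotalFinder(16, 5);
        tf.start();
        List<Integer> totals = new ArrayList<>();
        for (int i = 0; i < 3; i++) {
            totals.add(tf.nextTotal());
        }
        tf.join();
        return totals;
    }

    public static void main(String[] args) throws InterruptedException {
        for (int total : collect()) {
            System.out.println("Current total: " + total);
        }
    }
}
```

    Each release/acquire pair also gives the memory visibility that the unsynchronized write in resetWaitLock lacked, so count can stay a plain field.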