[Java] Waiting for threads with CountDownLatch

In the main thread, let's create three threads (one Producer thread and two Consumer threads), let them do their work, and then print how long it took.

Main

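import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;

public class Main {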
    public static void main(String[] args) {
        Instant start = Instant.now();

        System.out.println("Start");
        int totalNumberOfTasks = 3;
        BlockingQueue<Integer> queue = new LinkedBlockingQueue<>(200);

        ExecutorService executorService = Executors.newFixedThreadPool(totalNumberOfTasks);
        executorService.submit(new Producer(queue));

        executorService.submit(new Consumer(queue));
        executorService.submit(new Consumer(queue));

        executorService.shutdown(); // stops accepting new tasks, but does not wait for running ones

        Instant finish = Instant.now();
        long timeElapsed = Duration.between(start, finish).toMillis();

        System.out.println("Finished");
        System.out.println("Method took: " + timeElapsed + "ms");
    }
}

Consumer

import java.util.concurrent.BlockingQueue;

public class Consumer implements Runnable {
    private final BlockingQueue<Integer> queue;

    public Consumer(BlockingQueue<Integer> queue) {
        this.queue = queue;
    }

    @Override
    public void run() {

        try {
            while (true) {
                Integer take = queue.take();
                if (take == -1) {
                    // -1 is the end-of-data marker; put it back so the other Consumer also stops
                    queue.put(-1);
                    break;
                }
                process(take);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    private void process(Integer take) throws InterruptedException {
        System.out.println("[Consumer]  Take : " + take);
        Thread.sleep(500);
    }

}

Producer

import java.util.concurrent.BlockingQueue;

public class Producer implements Runnable {

    private final BlockingQueue<Integer> queue;

    public Producer(BlockingQueue<Integer> queue) {
        this.queue = queue;
    }

    @Override
    public void run() {

        try {
            process();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }

    }

    private void process() throws InterruptedException {
        for (int i = 0; i < 100; i++) {
            System.out.println("[Producer] Put : " + i);
            queue.put(i);
            System.out.println("[Producer] Queue remainingCapacity : " + queue.remainingCapacity());
            Thread.sleep(100);
        }

        queue.put(-1); // end-of-data marker for the Consumers
    }
}
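
The Producer signals the end of the data by putting -1 after its last item, and each Consumer puts the -1 back before it exits so that the other Consumer also receives it and stops. A minimal, self-contained sketch of this "poison pill" shutdown, assuming the producer runs on the calling thread and using a hypothetical class `PoisonPillDemo` in place of the Producer and Consumer classes above:

```java
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.atomic.AtomicInteger;

public class PoisonPillDemo {
    private static final int POISON = -1; // same end-of-data marker the Producer puts last

    // The calling thread acts as the producer: it puts 0..items-1 and then POISON.
    // Returns how many real items the consumers processed in total.
    public static int run(int items, int consumers) throws InterruptedException {
        BlockingQueue<Integer> queue = new LinkedBlockingQueue<>();
        AtomicInteger processed = new AtomicInteger();
        Thread[] workers = new Thread[consumers];

        for (int c = 0; c < consumers; c++) {
            workers[c] = new Thread(() -> {
                try {
                    while (true) {
                        int value = queue.take();
                        if (value == POISON) {
                            queue.put(POISON); // pass the pill on so the next consumer also stops
                            break;
                        }
                        processed.incrementAndGet();
                    }
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            });
            workers[c].start();
        }

        for (int i = 0; i < items; i++) {
            queue.put(i);
        }
        queue.put(POISON);

        for (Thread w : workers) {
            w.join(); // returns only after that consumer has seen the pill
        }
        return processed.get();
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("processed: " + run(10, 2)); // processed: 10
    }
}
```

Without the `queue.put(POISON)` inside the consumer loop, only one consumer would see the pill; the other would block in `take()` forever and its `join()` would never return.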

The result is shown below. The main thread printed its output without waiting for the other threads to finish their work.

Start
Finished
[Producer] Put : 0
[Producer] Queue remainingCapacity : 199
Method took: 10ms
[Consumer]  Take : 0
[Producer] Put : 1

...

[Consumer]  Take : 96
[Consumer]  Take : 97
[Consumer]  Take : 98
[Consumer]  Take : 99

This is not what we intended: we wanted to print how long the work took. The CountDownLatch class can fix this.
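
Why does `main` not wait? `ExecutorService.shutdown()` only stops the pool from accepting new tasks; it returns immediately and does not wait for tasks that are already running. The sketch below (the class `ShutdownDemo` and its method are hypothetical names used only for illustration) submits one slow task and checks `isTerminated()` right after `shutdown()`:

```java
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class ShutdownDemo {
    // Returns whether the pool had already terminated at the moment shutdown() returned.
    public static boolean terminatedRightAfterShutdown() throws InterruptedException {
        ExecutorService pool = Executors.newFixedThreadPool(1);
        pool.submit(() -> {
            try {
                Thread.sleep(300); // simulate a slow task
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        });
        pool.shutdown();                            // does NOT block
        boolean terminated = pool.isTerminated();   // false: the task is still running
        pool.awaitTermination(5, TimeUnit.SECONDS); // clean up before returning
        return terminated;
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("terminated right after shutdown(): " + terminatedRightAfterShutdown());
    }
}
```

This matches the output above, where `Finished` was printed while the Producer and Consumers were still working.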

How to use CountDownLatch

CountDownLatch is a class that lets a thread wait until work in other threads has completed.
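
A few properties follow directly from the API: the count starts at the value passed to the constructor, `countDown()` decrements it but never below 0, and `await(timeout, unit)` returns `false` if the count has not reached 0 within the timeout. A `CountDownLatch` also cannot be reset once it reaches 0. A small sketch checking these points (the class and method names are hypothetical):

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

public class LatchBasics {
    // Creates a latch with initialCount, calls countDown() `calls` times, and returns getCount().
    public static long remainingAfter(int initialCount, int calls) {
        CountDownLatch latch = new CountDownLatch(initialCount);
        for (int i = 0; i < calls; i++) {
            latch.countDown(); // once the count is 0, further calls have no effect
        }
        return latch.getCount();
    }

    // Waits at most timeoutMs; true if the count reached 0, false if the wait timed out.
    public static boolean releasedWithin(CountDownLatch latch, long timeoutMs) throws InterruptedException {
        return latch.await(timeoutMs, TimeUnit.MILLISECONDS);
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println(remainingAfter(3, 1)); // 2
        System.out.println(remainingAfter(3, 5)); // 0, the count never goes negative
        System.out.println(releasedWithin(new CountDownLatch(0), 100)); // true, already open
        System.out.println(releasedWithin(new CountDownLatch(1), 100)); // false, nobody counts down
    }
}
```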

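As shown in the figure below, thread TA starts threads T1, T2, and T3 and then waits as long as cnt is greater than 0.

When T1, T2, and T3 finish their work, each one calls countDown() to decrement cnt. Once cnt reaches 0, the blocked thread TA is released and carries on with the rest of its work.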

Source: https://www.programmersought.com/article/36987358114/

// Create the latch with an initial count
CountDownLatch countDownLatch = new CountDownLatch(5);

// Decrement the count by 1
countDownLatch.countDown();

// Block until the count reaches 0 (may throw InterruptedException)
countDownLatch.await();
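
The TA / T1, T2, T3 scenario described above can be sketched as a runnable program. In this sketch the calling thread plays the role of TA, the workers are plain `Thread`s, and the 50 ms sleep is an assumed stand-in for real work; `LatchFigureDemo` is a hypothetical name:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicInteger;

public class LatchFigureDemo {
    // The calling thread acts as TA: it starts `workers` threads (T1, T2, ...) and waits for all of them.
    public static int runAndWait(int workers) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(workers); // cnt = workers
        AtomicInteger finished = new AtomicInteger();

        for (int i = 1; i <= workers; i++) {
            Thread worker = new Thread(() -> {
                try {
                    Thread.sleep(50); // simulate some work
                    finished.incrementAndGet();
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    latch.countDown(); // cnt--
                }
            }, "T" + i);
            worker.start();
        }

        latch.await(); // TA blocks while cnt > 0
        return finished.get(); // all workers have counted down, so every increment is visible
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("TA resumed; finished workers: " + runAndWait(3));
    }
}
```

Because each worker increments the counter before calling `countDown()`, and `await()` returns only after the count reaches 0, TA is guaranteed to observe all three increments.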

Waiting for other threads to finish

Let's use CountDownLatch to modify the code so that the main thread waits until the other threads have finished their work.

Main

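import java.time.Duration;
import java.time.Instant;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;

public class Main {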
    public static void main(String[] args) throws InterruptedException {
        Instant start = Instant.now();

        System.out.println("Start");
        int totalNumberOfTasks = 3;
        BlockingQueue<Integer> queue = new LinkedBlockingQueue<>(200);

        ExecutorService executorService = Executors.newFixedThreadPool(totalNumberOfTasks);
        CountDownLatch latch = new CountDownLatch(totalNumberOfTasks);

        executorService.submit(new Producer(queue, latch));

        executorService.submit(new Consumer(queue, latch));
        executorService.submit(new Consumer(queue, latch));

        executorService.shutdown();
        latch.await(); // block until all three tasks have called countDown()

        Instant finish = Instant.now();
        long timeElapsed = Duration.between(start, finish).toMillis();

        System.out.println("Finished");
        System.out.println("Method took: " + timeElapsed + "ms");
    }
}

Producer

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;

public class Producer implements Runnable {

    private final BlockingQueue<Integer> queue;
    private final CountDownLatch latch;

    public Producer(BlockingQueue<Integer> queue, CountDownLatch latch) {
        this.queue = queue;
        this.latch = latch;
    }

    @Override
    public void run() {

        try {
            process();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            // Always count down, even if process() fails, so Main does not wait forever
            latch.countDown();
        }
    }

    private void process() throws InterruptedException {
        for (int i = 0; i < 100; i++) {
            System.out.println("[Producer] Put : " + i);
            queue.put(i);
            System.out.println("[Producer] Queue remainingCapacity : " + queue.remainingCapacity());
            Thread.sleep(100);
        }

        queue.put(-1); // end-of-data marker for the Consumers
    }
}

Consumer

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;

public class Consumer implements Runnable {

    private final BlockingQueue<Integer> queue;
    private final CountDownLatch latch;

    public Consumer(BlockingQueue<Integer> queue, CountDownLatch latch) {
        this.queue = queue;
        this.latch = latch;
    }

    @Override
    public void run() {

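        try {
            while (true) {
                Integer take = queue.take();
                if (take == -1) {
                    // -1 is the end-of-data marker; put it back so the other Consumer also stops
                    queue.put(-1);
                    break;
                }
                process(take);
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            // Always count down, even if processing fails, so Main does not wait forever
            latch.countDown();
        }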
    }

    private void process(Integer take) throws InterruptedException {
        System.out.println("[Consumer]  Take : " + take);
        Thread.sleep(500);
    }

}
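
A latch only opens if every task actually reaches `countDown()`. If a task can throw an unchecked exception before that call, `latch.await()` in `Main` blocks forever; calling `countDown()` in a `finally` block avoids this. The sketch below compares both placements. `FinallyCountDownDemo` is a hypothetical class, `failingWork()` is a hypothetical stand-in for `process()` that always throws, and a timed `await()` is used so the broken variant does not hang:

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class FinallyCountDownDemo {
    // Hypothetical stand-in for process(): always fails with an unchecked exception.
    static void failingWork() {
        throw new IllegalStateException("simulated failure");
    }

    // Returns true if await() was released within 500 ms despite the task failing.
    public static boolean released(boolean countDownInFinally) throws InterruptedException {
        CountDownLatch latch = new CountDownLatch(1);
        ExecutorService pool = Executors.newSingleThreadExecutor();
        pool.submit(() -> {
            if (countDownInFinally) {
                try {
                    failingWork();
                } finally {
                    latch.countDown(); // runs even though failingWork() threw
                }
            } else {
                failingWork();
                latch.countDown(); // never reached: the exception skips it
            }
        });
        pool.shutdown();
        return latch.await(500, TimeUnit.MILLISECONDS); // timed wait so the broken case returns
    }

    public static void main(String[] args) throws InterruptedException {
        System.out.println("countDown() in finally:     " + released(true));
        System.out.println("countDown() after the work: " + released(false));
    }
}
```

`released(true)` returns `true`, while `released(false)` returns `false` once the 500 ms timeout expires; with a plain `await()` the second case would never return.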

Result

Start
[Producer] Put : 0
[Producer] Queue remainingCapacity : 199
[Consumer]  Take : 0
[Producer] Put : 1

...

[Consumer]  Take : 196
[Consumer]  Take : 197
[Consumer]  Take : 198
[Consumer]  Take : 199
Finished
Method took: 33884ms