Pārlūkot izejas kodu

Correctly release threads from starting gate in o.e.c.c.CacheTests

Jason Tedor 9 gadi atpakaļ
vecāks
revīzija
35cc749c9a

+ 113 - 79
core/src/test/java/org/elasticsearch/common/cache/CacheTests.java

@@ -494,33 +494,41 @@ public class CacheTests extends ESTestCase {
     public void testComputeIfAbsentCallsOnce() throws InterruptedException {
         int numberOfThreads = randomIntBetween(2, 32);
         final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
-        List<Thread> threads = new ArrayList<>();
         AtomicReferenceArray flags = new AtomicReferenceArray(numberOfEntries);
         for (int j = 0; j < numberOfEntries; j++) {
             flags.set(j, false);
         }
-        CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
+        CountDownLatch startGate = new CountDownLatch(1);
+        CountDownLatch endGate = new CountDownLatch(numberOfThreads);
+        AtomicBoolean interrupted = new AtomicBoolean();
         for (int i = 0; i < numberOfThreads; i++) {
             Thread thread = new Thread(() -> {
-                latch.countDown();
-                for (int j = 0; j < numberOfEntries; j++) {
+                try {
                     try {
-                        cache.computeIfAbsent(j, key -> {
-                            assertTrue(flags.compareAndSet(key, false, true));
-                            return Integer.toString(key);
-                        });
-                    } catch (ExecutionException e) {
-                        throw new RuntimeException(e);
+                        startGate.await();
+                    } catch (InterruptedException e) {
+                        interrupted.set(true);
+                        return;
                     }
+                    for (int j = 0; j < numberOfEntries; j++) {
+                        try {
+                            cache.computeIfAbsent(j, key -> {
+                                assertTrue(flags.compareAndSet(key, false, true));
+                                return Integer.toString(key);
+                            });
+                        } catch (ExecutionException e) {
+                            throw new RuntimeException(e);
+                        }
+                    }
+                } finally {
+                    endGate.countDown();
                 }
             });
-            threads.add(thread);
             thread.start();
         }
-        latch.countDown();
-        for (Thread thread : threads) {
-            thread.join();
-        }
+        startGate.countDown();
+        endGate.await();
+        assertFalse(interrupted.get());
     }
 
     public void testComputeIfAbsentThrowsExceptionIfLoaderReturnsANullValue() {
@@ -560,30 +568,39 @@ public class CacheTests extends ESTestCase {
 
         int numberOfThreads = randomIntBetween(2, 32);
         final Cache<Key, Integer> cache = CacheBuilder.<Key, Integer>builder().build();
-        CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
+        CountDownLatch startGate = new CountDownLatch(1);
         CountDownLatch deadlockLatch = new CountDownLatch(numberOfThreads);
+        AtomicBoolean interrupted = new AtomicBoolean();
         List<Thread> threads = new ArrayList<>();
         for (int i = 0; i < numberOfThreads; i++) {
             Thread thread = new Thread(() -> {
-                Random random = new Random(random().nextLong());
-                latch.countDown();
-                for (int j = 0; j < numberOfEntries; j++) {
-                    Key key = new Key(random.nextInt(numberOfEntries));
+                try {
                     try {
-                        cache.computeIfAbsent(key, k -> {
-                            if (k.key == 0) {
-                                return 0;
-                            } else {
-                                Integer value = cache.get(new Key(k.key / 2));
-                                return value != null ? value : 0;
-                            }
-                        });
-                    } catch (ExecutionException e) {
-                        fail(e.getMessage());
+                        startGate.await();
+                    } catch (InterruptedException e) {
+                        interrupted.set(true);
+                        return;
                     }
+                    Random random = new Random(random().nextLong());
+                    for (int j = 0; j < numberOfEntries; j++) {
+                        Key key = new Key(random.nextInt(numberOfEntries));
+                        try {
+                            cache.computeIfAbsent(key, k -> {
+                                if (k.key == 0) {
+                                    return 0;
+                                } else {
+                                    Integer value = cache.get(new Key(k.key / 2));
+                                    return value != null ? value : 0;
+                                }
+                            });
+                        } catch (ExecutionException e) {
+                            fail(e.getMessage());
+                        }
+                    }
+                } finally {
+                    // successfully avoided deadlock, release the main thread
+                    deadlockLatch.countDown();
                 }
-                // successfully avoided deadlock, release the main thread
-                deadlockLatch.countDown();
             });
             threads.add(thread);
             thread.start();
@@ -614,7 +631,7 @@ public class CacheTests extends ESTestCase {
         }, 1, 1, TimeUnit.SECONDS);
 
         // everything is setup, release the hounds
-        latch.countDown();
+        startGate.countDown();
 
         // wait for either deadlock to be detected or the threads to terminate
         deadlockLatch.await();
@@ -628,49 +645,57 @@ public class CacheTests extends ESTestCase {
     public void testCachePollution() throws InterruptedException {
         int numberOfThreads = randomIntBetween(2, 32);
         final Cache<Integer, String> cache = CacheBuilder.<Integer, String>builder().build();
-        CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
-        List<Thread> threads = new ArrayList<>();
+        CountDownLatch startGate = new CountDownLatch(1);
+        CountDownLatch endGate = new CountDownLatch(numberOfThreads);
+        AtomicBoolean interrupted = new AtomicBoolean();
         for (int i = 0; i < numberOfThreads; i++) {
             Thread thread = new Thread(() -> {
-                latch.countDown();
-                Random random = new Random(random().nextLong());
-                for (int j = 0; j < numberOfEntries; j++) {
-                    Integer key = random.nextInt(numberOfEntries);
-                    boolean first;
-                    boolean second;
-                    do {
-                        first = random.nextBoolean();
-                        second = random.nextBoolean();
-                    } while (first && second);
-                    if (first) {
-                        try {
-                            cache.computeIfAbsent(key, k -> {
-                                if (random.nextBoolean()) {
-                                    return Integer.toString(k);
-                                } else {
-                                    throw new Exception("testCachePollution");
-                                }
-                            });
-                        } catch (ExecutionException e) {
-                            assertNotNull(e.getCause());
-                            assertThat(e.getCause(), instanceOf(Exception.class));
-                            assertEquals(e.getCause().getMessage(), "testCachePollution");
+                try {
+                    try {
+                        startGate.await();
+                    } catch (InterruptedException e) {
+                        interrupted.set(true);
+                        return;
+                    }
+                    Random random = new Random(random().nextLong());
+                    for (int j = 0; j < numberOfEntries; j++) {
+                        Integer key = random.nextInt(numberOfEntries);
+                        boolean first;
+                        boolean second;
+                        do {
+                            first = random.nextBoolean();
+                            second = random.nextBoolean();
+                        } while (first && second);
+                        if (first) {
+                            try {
+                                cache.computeIfAbsent(key, k -> {
+                                    if (random.nextBoolean()) {
+                                        return Integer.toString(k);
+                                    } else {
+                                        throw new Exception("testCachePollution");
+                                    }
+                                });
+                            } catch (ExecutionException e) {
+                                assertNotNull(e.getCause());
+                                assertThat(e.getCause(), instanceOf(Exception.class));
+                                assertEquals(e.getCause().getMessage(), "testCachePollution");
+                            }
+                        } else if (second) {
+                            cache.invalidate(key);
+                        } else {
+                            cache.get(key);
                         }
-                    } else if (second) {
-                        cache.invalidate(key);
-                    } else {
-                        cache.get(key);
                     }
+                } finally {
+                    endGate.countDown();
                 }
             });
-            threads.add(thread);
             thread.start();
         }
 
-        latch.countDown();
-        for (Thread thread : threads) {
-            thread.join();
-        }
+        startGate.countDown();
+        endGate.await();
+        assertFalse(interrupted.get());
     }
 
     // test that the cache is not corrupted under lots of concurrent modifications, even hitting the same key
@@ -683,24 +708,33 @@ public class CacheTests extends ESTestCase {
                         .weigher((k, v) -> 2)
                         .build();
 
-        CountDownLatch latch = new CountDownLatch(1 + numberOfThreads);
-        List<Thread> threads = new ArrayList<>();
+        CountDownLatch startGate = new CountDownLatch(1);
+        CountDownLatch endGate = new CountDownLatch(numberOfThreads);
+        AtomicBoolean interrupted = new AtomicBoolean();
         for (int i = 0; i < numberOfThreads; i++) {
             Thread thread = new Thread(() -> {
-                Random random = new Random(random().nextLong());
-                latch.countDown();
-                for (int j = 0; j < numberOfEntries; j++) {
-                    Integer key = random.nextInt(numberOfEntries);
-                    cache.put(key, Integer.toString(j));
+                try {
+                    try {
+                        startGate.await();
+                    } catch (InterruptedException e) {
+                        interrupted.set(true);
+                        return;
+                    }
+                    Random random = new Random(random().nextLong());
+                    for (int j = 0; j < numberOfEntries; j++) {
+                        Integer key = random.nextInt(numberOfEntries);
+                        cache.put(key, Integer.toString(j));
+                    }
+                } finally {
+                    endGate.countDown();
                 }
             });
-            threads.add(thread);
             thread.start();
         }
-        latch.countDown();
-        for (Thread thread : threads) {
-            thread.join();
-        }
+        startGate.countDown();
+        endGate.await();
+        assertFalse(interrupted.get());
+
         cache.refresh();
         assertEquals(500, cache.count());
     }