Browse Source

Misc improvements to AbstractRefCounted (#92616)

Adds a null check and a `toString()` implementation which passes through
to the wrapped runnable. Also renames `RefCountedTests` to
`AbstractRefCountedTests` since they're really all about testing this
specific implementation.
David Turner 2 years ago
parent
commit
eb8cb109a4

+ 8 - 0
libs/core/src/main/java/org/elasticsearch/core/AbstractRefCounted.java

@@ -8,6 +8,7 @@
 
 package org.elasticsearch.core;
 
+import java.util.Objects;
 import java.util.concurrent.atomic.AtomicInteger;
 
 /**
@@ -94,11 +95,18 @@ public abstract class AbstractRefCounted implements RefCounted {
      * Construct an {@link AbstractRefCounted} which runs the given {@link Runnable} when all references are released.
      */
     public static AbstractRefCounted of(Runnable onClose) {
+        Objects.requireNonNull(onClose);
         return new AbstractRefCounted() {
             @Override
             protected void closeInternal() {
                 onClose.run();
             }
+
+            @Override
+            public String toString() {
+                return "refCounted[" + onClose + "]";
+            }
         };
     }
+
 }

+ 35 - 34
libs/core/src/test/java/org/elasticsearch/common/util/concurrent/RefCountedTests.java → libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java

@@ -5,23 +5,20 @@
  * in compliance with, at your election, the Elastic License 2.0 or the Server
  * Side Public License, v 1.
  */
-package org.elasticsearch.common.util.concurrent;
+package org.elasticsearch.core;
 
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.test.ESTestCase;
-import org.hamcrest.Matchers;
 
-import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
 
-public class RefCountedTests extends ESTestCase {
+public class AbstractRefCountedTests extends ESTestCase {
 
     public void testRefCount() {
-        MyRefCounted counted = new MyRefCounted();
+        final RefCounted counted = createRefCounted();
 
         int incs = randomIntBetween(1, 100);
         for (int i = 0; i < incs; i++) {
@@ -30,12 +27,12 @@ public class RefCountedTests extends ESTestCase {
             } else {
                 assertTrue(counted.tryIncRef());
             }
-            counted.ensureOpen();
+            assertTrue(counted.hasReferences());
         }
 
         for (int i = 0; i < incs; i++) {
             counted.decRef();
-            counted.ensureOpen();
+            assertTrue(counted.hasReferences());
         }
 
         counted.incRef();
@@ -46,12 +43,12 @@ public class RefCountedTests extends ESTestCase {
             } else {
                 assertTrue(counted.tryIncRef());
             }
-            counted.ensureOpen();
+            assertTrue(counted.hasReferences());
         }
 
         for (int i = 0; i < incs; i++) {
             counted.decRef();
-            counted.ensureOpen();
+            assertTrue(counted.hasReferences());
         }
 
         counted.decRef();
@@ -60,29 +57,29 @@ public class RefCountedTests extends ESTestCase {
             expectThrows(IllegalStateException.class, counted::incRef).getMessage(),
             equalTo(AbstractRefCounted.ALREADY_CLOSED_MESSAGE)
         );
-        assertThat(expectThrows(IllegalStateException.class, counted::ensureOpen).getMessage(), equalTo("closed"));
+        assertFalse(counted.hasReferences());
     }
 
     public void testMultiThreaded() throws InterruptedException {
-        final MyRefCounted counted = new MyRefCounted();
-        Thread[] threads = new Thread[randomIntBetween(2, 5)];
+        final AbstractRefCounted counted = createRefCounted();
+        final Thread[] threads = new Thread[randomIntBetween(2, 5)];
         final CountDownLatch latch = new CountDownLatch(1);
-        final CopyOnWriteArrayList<Exception> exceptions = new CopyOnWriteArrayList<>();
         for (int i = 0; i < threads.length; i++) {
             threads[i] = new Thread(() -> {
                 try {
                     latch.await();
                     for (int j = 0; j < 10000; j++) {
-                        counted.incRef();
                         assertTrue(counted.hasReferences());
-                        try {
-                            counted.ensureOpen();
-                        } finally {
-                            counted.decRef();
+                        if (randomBoolean()) {
+                            counted.incRef();
+                        } else {
+                            assertTrue(counted.tryIncRef());
                         }
+                        assertTrue(counted.hasReferences());
+                        counted.decRef();
                     }
                 } catch (Exception e) {
-                    exceptions.add(e);
+                    throw new AssertionError(e);
                 }
             });
             threads[i].start();
@@ -92,31 +89,35 @@ public class RefCountedTests extends ESTestCase {
             thread.join();
         }
         counted.decRef();
-        assertThat(expectThrows(IllegalStateException.class, counted::ensureOpen).getMessage(), equalTo("closed"));
+        assertFalse(counted.hasReferences());
         assertThat(
             expectThrows(IllegalStateException.class, counted::incRef).getMessage(),
             equalTo(AbstractRefCounted.ALREADY_CLOSED_MESSAGE)
         );
         assertThat(counted.refCount(), is(0));
         assertFalse(counted.hasReferences());
-        assertThat(exceptions, Matchers.emptyIterable());
     }
 
-    private static final class MyRefCounted extends AbstractRefCounted {
+    public void testToString() {
+        assertEquals("refCounted[runnable description]", createRefCounted().toString());
+    }
 
-        private final AtomicBoolean closed = new AtomicBoolean(false);
+    public void testNullCheck() {
+        expectThrows(NullPointerException.class, () -> AbstractRefCounted.of(null));
+    }
 
-        @Override
-        protected void closeInternal() {
-            this.closed.set(true);
-        }
+    private static AbstractRefCounted createRefCounted() {
+        final var closed = new AtomicBoolean();
+        return AbstractRefCounted.of(new Runnable() {
+            @Override
+            public void run() {
+                assertTrue(closed.compareAndSet(false, true));
+            }
 
-        public void ensureOpen() {
-            if (closed.get()) {
-                assertEquals(0, this.refCount());
-                assertFalse(hasReferences());
-                throw new IllegalStateException("closed");
+            @Override
+            public String toString() {
+                return "runnable description";
             }
-        }
+        });
     }
 }