Browse Source

Introduce RefCounted#mustIncRef (#102515)

In several places we acquire a ref to a resource that we are certain is
not closed, so this commit adds a utility for asserting this to be the
case. This also helps a little with mocks since boolean methods like
`tryIncRef()` return `false` on mock objects by default, but void
methods like `mustIncRef()` default to being a no-op.
David Turner 1 year ago
parent
commit
b2127ec2f9
23 changed files with 51 additions and 62 deletions
  1. 2 1
      libs/core/src/main/java/org/elasticsearch/core/AbstractRefCounted.java
  2. 12 0
      libs/core/src/main/java/org/elasticsearch/core/RefCounted.java
  3. 7 21
      modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java
  4. 1 1
      modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java
  5. 5 9
      server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java
  6. 2 2
      server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java
  7. 1 1
      server/src/main/java/org/elasticsearch/action/support/master/TransportMasterNodeAction.java
  8. 1 1
      server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java
  9. 1 2
      server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java
  10. 1 2
      server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java
  11. 1 1
      server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java
  12. 2 2
      server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java
  13. 1 1
      server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java
  14. 2 2
      server/src/main/java/org/elasticsearch/transport/InboundHandler.java
  15. 1 1
      server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java
  16. 2 2
      server/src/main/java/org/elasticsearch/transport/TransportService.java
  17. 3 2
      server/src/test/java/org/elasticsearch/action/support/RefCountingListenerTests.java
  18. 3 2
      server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java
  19. 1 1
      test/framework/src/main/java/org/elasticsearch/transport/DisruptableMockTransport.java
  20. 1 4
      x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java
  21. 1 1
      x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java
  22. 0 1
      x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherServiceTests.java
  23. 0 2
      x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStoreTests.java

+ 2 - 1
libs/core/src/main/java/org/elasticsearch/core/AbstractRefCounted.java

@@ -19,6 +19,7 @@ import java.util.Objects;
 public abstract class AbstractRefCounted implements RefCounted {
 
     public static final String ALREADY_CLOSED_MESSAGE = "already closed, can't increment ref count";
+    public static final String INVALID_DECREF_MESSAGE = "invalid decRef call: already closed";
 
     private static final VarHandle VH_REFCOUNT_FIELD;
 
@@ -63,7 +64,7 @@ public abstract class AbstractRefCounted implements RefCounted {
     public final boolean decRef() {
         touch();
         int i = (int) VH_REFCOUNT_FIELD.getAndAdd(this, -1);
-        assert i > 0 : "invalid decRef call: already closed";
+        assert i > 0 : INVALID_DECREF_MESSAGE;
         if (i == 1) {
             try {
                 closeInternal();

+ 12 - 0
libs/core/src/main/java/org/elasticsearch/core/RefCounted.java

@@ -62,4 +62,16 @@ public interface RefCounted {
      * @return whether there are currently any active references to this object.
      */
     boolean hasReferences();
+
+    /**
+     * Similar to {@link #incRef()} except that it also asserts that it managed to acquire the ref, for use in situations where it is a bug
+     * if all refs have been released.
+     */
+    default void mustIncRef() {
+        if (tryIncRef()) {
+            return;
+        }
+        assert false : AbstractRefCounted.ALREADY_CLOSED_MESSAGE;
+        incRef(); // throws an ISE
+    }
 }

+ 7 - 21
modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java

@@ -180,15 +180,11 @@ public class GeoIpDownloaderTests extends ESTestCase {
     public void testIndexChunksNoData() throws IOException {
         client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
             assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
-            var flushResponse = mock(FlushResponse.class);
-            when(flushResponse.hasReferences()).thenReturn(true);
-            flushResponseActionListener.onResponse(flushResponse);
+            flushResponseActionListener.onResponse(mock(FlushResponse.class));
         });
         client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
             assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
-            var refreshResponse = mock(RefreshResponse.class);
-            when(refreshResponse.hasReferences()).thenReturn(true);
-            flushResponseActionListener.onResponse(refreshResponse);
+            flushResponseActionListener.onResponse(mock(RefreshResponse.class));
         });
 
         InputStream empty = new ByteArrayInputStream(new byte[0]);
@@ -198,15 +194,11 @@ public class GeoIpDownloaderTests extends ESTestCase {
     public void testIndexChunksMd5Mismatch() {
         client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
             assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
-            var flushResponse = mock(FlushResponse.class);
-            when(flushResponse.hasReferences()).thenReturn(true);
-            flushResponseActionListener.onResponse(flushResponse);
+            flushResponseActionListener.onResponse(mock(FlushResponse.class));
         });
         client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
             assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
-            var refreshResponse = mock(RefreshResponse.class);
-            when(refreshResponse.hasReferences()).thenReturn(true);
-            flushResponseActionListener.onResponse(refreshResponse);
+            flushResponseActionListener.onResponse(mock(RefreshResponse.class));
         });
 
         IOException exception = expectThrows(
@@ -238,21 +230,15 @@ public class GeoIpDownloaderTests extends ESTestCase {
             assertEquals("test", source.get("name"));
             assertArrayEquals(chunksData[chunk], (byte[]) source.get("data"));
             assertEquals(chunk + 15, source.get("chunk"));
-            var indexResponse = mock(IndexResponse.class);
-            when(indexResponse.hasReferences()).thenReturn(true);
-            listener.onResponse(indexResponse);
+            listener.onResponse(mock(IndexResponse.class));
         });
         client.addHandler(FlushAction.INSTANCE, (FlushRequest request, ActionListener<FlushResponse> flushResponseActionListener) -> {
             assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
-            var flushResponse = mock(FlushResponse.class);
-            when(flushResponse.hasReferences()).thenReturn(true);
-            flushResponseActionListener.onResponse(flushResponse);
+            flushResponseActionListener.onResponse(mock(FlushResponse.class));
         });
         client.addHandler(RefreshAction.INSTANCE, (RefreshRequest request, ActionListener<RefreshResponse> flushResponseActionListener) -> {
             assertArrayEquals(new String[] { GeoIpDownloader.DATABASES_INDEX }, request.indices());
-            var refreshResponse = mock(RefreshResponse.class);
-            when(refreshResponse.hasReferences()).thenReturn(true);
-            flushResponseActionListener.onResponse(refreshResponse);
+            flushResponseActionListener.onResponse(mock(RefreshResponse.class));
         });
 
         InputStream big = new ByteArrayInputStream(bigArray);

+ 1 - 1
modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Service.java

@@ -135,7 +135,7 @@ class S3Service implements Closeable {
                 return existing;
             }
             final AmazonS3Reference clientReference = new AmazonS3Reference(buildClient(clientSettings));
-            clientReference.incRef();
+            clientReference.mustIncRef();
             clientsCache = Maps.copyMapWithAddedEntry(clientsCache, clientSettings, clientReference);
             return clientReference;
         }

+ 5 - 9
server/src/main/java/org/elasticsearch/action/support/RefCountingRunnable.java

@@ -63,7 +63,6 @@ import org.elasticsearch.core.Releasables;
 public final class RefCountingRunnable implements Releasable {
 
     private static final Logger logger = LogManager.getLogger(RefCountingRunnable.class);
-    static final String ALREADY_CLOSED_MESSAGE = "already closed, cannot acquire or release any further refs";
 
     private final RefCounted refCounted;
 
@@ -86,14 +85,11 @@ public final class RefCountingRunnable implements Releasable {
      * will be ignored otherwise. This deviates from the contract of {@link java.io.Closeable}.
      */
     public Releasable acquire() {
-        if (refCounted.tryIncRef()) {
-            // All refs are considered equal so there's no real need to allocate a new object here, although note that this deviates
-            // (subtly) from the docs for Closeable#close() which indicate that it should be idempotent. But only if assertions are
-            // disabled, and if assertions are enabled then we are asserting that we never double-close these things anyway.
-            return Releasables.assertOnce(this);
-        }
-        assert false : ALREADY_CLOSED_MESSAGE;
-        throw new IllegalStateException(ALREADY_CLOSED_MESSAGE);
+        refCounted.mustIncRef();
+        // All refs are considered equal so there's no real need to allocate a new object here, although note that this deviates (subtly)
+        // from the docs for Closeable#close() which indicate that it should be idempotent. But only if assertions are disabled, and if
+        // assertions are enabled then we are asserting that we never double-close these things anyway.
+        return Releasables.assertOnce(this);
     }
 
     /**

+ 2 - 2
server/src/main/java/org/elasticsearch/action/support/broadcast/node/TransportBroadcastByNodeAction.java

@@ -228,7 +228,7 @@ public abstract class TransportBroadcastByNodeAction<
     @Override
     protected void doExecute(Task task, Request request, ActionListener<Response> listener) {
         // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can
-        request.incRef();
+        request.mustIncRef();
         executor.execute(ActionRunnable.wrapReleasing(listener, request::decRef, l -> doExecuteForked(task, request, listener)));
     }
 
@@ -474,7 +474,7 @@ public abstract class TransportBroadcastByNodeAction<
         }
 
         NodeRequest(Request indicesLevelRequest, List<ShardRouting> shards, String nodeId) {
-            indicesLevelRequest.incRef();
+            indicesLevelRequest.mustIncRef();
             this.indicesLevelRequest = indicesLevelRequest;
             this.shards = shards;
             this.nodeId = nodeId;

+ 1 - 1
server/src/main/java/org/elasticsearch/action/support/master/TransportMasterNodeAction.java

@@ -169,7 +169,7 @@ public abstract class TransportMasterNodeAction<Request extends MasterNodeReques
         if (task != null) {
             request.setParentTask(clusterService.localNode().getId(), task.getId());
         }
-        request.incRef();
+        request.mustIncRef();
         new AsyncSingleAction(task, request, ActionListener.runBefore(listener, request::decRef)).doStart(state);
     }
 

+ 1 - 1
server/src/main/java/org/elasticsearch/action/support/tasks/TransportTasksAction.java

@@ -290,7 +290,7 @@ public abstract class TransportTasksAction<
 
         protected NodeTaskRequest(TasksRequest tasksRequest) {
             super();
-            tasksRequest.incRef();
+            tasksRequest.mustIncRef();
             this.tasksRequest = tasksRequest;
         }
 

+ 1 - 2
server/src/main/java/org/elasticsearch/client/internal/support/AbstractClient.java

@@ -1612,9 +1612,8 @@ public abstract class AbstractClient implements Client {
 
         @Override
         public final void onResponse(R result) {
-            assert result.hasReferences();
             if (set(result)) {
-                result.incRef();
+                result.mustIncRef();
             }
         }
 

+ 1 - 2
server/src/main/java/org/elasticsearch/cluster/coordination/JoinValidationService.java

@@ -363,8 +363,7 @@ public class JoinValidationService {
                 );
                 return;
             }
-            assert bytes.hasReferences() : "already closed";
-            bytes.incRef();
+            bytes.mustIncRef();
             transportService.sendRequest(
                 connection,
                 JOIN_VALIDATE_ACTION_NAME,

+ 1 - 1
server/src/main/java/org/elasticsearch/common/util/CancellableSingleObjectCache.java

@@ -192,7 +192,7 @@ public abstract class CancellableSingleObjectCache<Input, Key, Value> {
 
         CachedItem(Key key) {
             this.key = key;
-            incRef(); // start with a refcount of 2 so we're not closed while adding the first listener
+            mustIncRef(); // start with a refcount of 2 so we're not closed while adding the first listener
             this.future.addListener(new ActionListener<>() {
                 @Override
                 public void onResponse(Value value) {

+ 2 - 2
server/src/main/java/org/elasticsearch/common/util/concurrent/ThrottledIterator.java

@@ -88,7 +88,7 @@ public class ThrottledIterator<T> implements Releasable {
                 }
             }
             try (var itemRefs = new ItemRefCounted()) {
-                itemRefs.incRef();
+                itemRefs.mustIncRef();
                 itemConsumer.accept(Releasables.releaseOnce(itemRefs::decRef), item);
             } catch (Exception e) {
                 logger.error(Strings.format("exception when processing [%s] with [%s]", item, itemConsumer), e);
@@ -108,7 +108,7 @@ public class ThrottledIterator<T> implements Releasable {
         private boolean isRecursive = true;
 
         ItemRefCounted() {
-            refs.incRef();
+            refs.mustIncRef();
         }
 
         @Override

+ 1 - 1
server/src/main/java/org/elasticsearch/transport/ClusterConnectionManager.java

@@ -223,7 +223,7 @@ public class ClusterConnectionManager implements ConnectionManager {
                             IOUtils.closeWhileHandlingException(conn);
                         } else {
                             logger.debug("connected to node [{}]", node);
-                            managerRefs.incRef();
+                            managerRefs.mustIncRef();
                             try {
                                 connectionListener.onNodeConnected(node, conn);
                             } finally {

+ 2 - 2
server/src/main/java/org/elasticsearch/transport/InboundHandler.java

@@ -293,7 +293,7 @@ public class InboundHandler {
 
     private <T extends TransportRequest> void handleRequestForking(T request, RequestHandlerRegistry<T> reg, TransportChannel channel) {
         boolean success = false;
-        request.incRef();
+        request.mustIncRef();
         try {
             reg.getExecutor().execute(threadPool.getThreadContext().preserveContextWithTracing(new AbstractRunnable() {
                 @Override
@@ -381,7 +381,7 @@ public class InboundHandler {
             // no need to provide a buffer release here, we never escape the buffer when handling directly
             doHandleResponse(handler, remoteAddress, stream, inboundMessage.getHeader(), () -> {});
         } else {
-            inboundMessage.incRef();
+            inboundMessage.mustIncRef();
             // release buffer once we deserialize the message, but have a fail-safe in #onAfter below in case that didn't work out
             final Releasable releaseBuffer = Releasables.releaseOnce(inboundMessage::decRef);
             executor.execute(new ForkingResponseHandlerRunnable(handler, null, threadPool) {

+ 1 - 1
server/src/main/java/org/elasticsearch/transport/TransportActionProxy.java

@@ -65,7 +65,7 @@ public final class TransportActionProxy {
                 @Override
                 public void handleResponse(TransportResponse response) {
                     try {
-                        response.incRef();
+                        response.mustIncRef();
                         channel.sendResponse(response);
                     } catch (IOException e) {
                         throw new UncheckedIOException(e);

+ 2 - 2
server/src/main/java/org/elasticsearch/transport/TransportService.java

@@ -1013,7 +1013,7 @@ public class TransportService extends AbstractLifecycleComponent
                 }
             } else {
                 boolean success = false;
-                request.incRef();
+                request.mustIncRef();
                 try {
                     executor.execute(threadPool.getThreadContext().preserveContextWithTracing(new AbstractRunnable() {
                         @Override
@@ -1479,7 +1479,7 @@ public class TransportService extends AbstractLifecycleComponent
                     if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
                         processResponse(handler, response);
                     } else {
-                        response.incRef();
+                        response.mustIncRef();
                         executor.execute(new ForkingResponseHandlerRunnable(handler, null, threadPool) {
                             @Override
                             protected void doRun() {

+ 3 - 2
server/src/test/java/org/elasticsearch/action/support/RefCountingListenerTests.java

@@ -11,6 +11,7 @@ package org.elasticsearch.action.support;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.common.util.concurrent.RunOnce;
+import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.CheckedConsumer;
 import org.elasticsearch.test.ESTestCase;
 import org.elasticsearch.test.ReachabilityChecker;
@@ -174,10 +175,10 @@ public class RefCountingListenerTests extends ESTestCase {
             final String expectedMessage;
             if (randomBoolean()) {
                 throwingRunnable = refs::acquire;
-                expectedMessage = RefCountingRunnable.ALREADY_CLOSED_MESSAGE;
+                expectedMessage = AbstractRefCounted.ALREADY_CLOSED_MESSAGE;
             } else {
                 throwingRunnable = refs::close;
-                expectedMessage = "invalid decRef call: already closed";
+                expectedMessage = AbstractRefCounted.INVALID_DECREF_MESSAGE;
             }
 
             assertEquals(expectedMessage, expectThrows(AssertionError.class, throwingRunnable).getMessage());

+ 3 - 2
server/src/test/java/org/elasticsearch/action/support/RefCountingRunnableTests.java

@@ -13,6 +13,7 @@ import org.elasticsearch.common.util.concurrent.AbstractRunnable;
 import org.elasticsearch.common.util.concurrent.EsExecutors;
 import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
+import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.test.ESTestCase;
 
@@ -166,10 +167,10 @@ public class RefCountingRunnableTests extends ESTestCase {
             final String expectedMessage;
             if (randomBoolean()) {
                 throwingRunnable = randomBoolean() ? refs::acquire : refs::acquireListener;
-                expectedMessage = RefCountingRunnable.ALREADY_CLOSED_MESSAGE;
+                expectedMessage = AbstractRefCounted.ALREADY_CLOSED_MESSAGE;
             } else {
                 throwingRunnable = refs::close;
-                expectedMessage = "invalid decRef call: already closed";
+                expectedMessage = AbstractRefCounted.INVALID_DECREF_MESSAGE;
             }
 
             assertEquals(expectedMessage, expectThrows(AssertionError.class, throwingRunnable).getMessage());

+ 1 - 1
test/framework/src/main/java/org/elasticsearch/transport/DisruptableMockTransport.java

@@ -150,7 +150,7 @@ public abstract class DisruptableMockTransport extends MockTransport {
         assert destinationTransport.getLocalNode().equals(getLocalNode()) == false
             : "non-local message from " + getLocalNode() + " to itself";
 
-        request.incRef();
+        request.mustIncRef();
 
         destinationTransport.execute(new RebootSensitiveRunnable() {
             @Override

+ 1 - 4
x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/utils/persistence/BatchedDocumentsIteratorTests.java

@@ -140,9 +140,7 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         doAnswer(invocationOnMock -> {
             ActionListener<ClearScrollResponse> listener = (ActionListener<ClearScrollResponse>) invocationOnMock.getArguments()[2];
             wasScrollCleared = true;
-            var clearScrollResponse = mock(ClearScrollResponse.class);
-            when(clearScrollResponse.hasReferences()).thenReturn(true);
-            listener.onResponse(clearScrollResponse);
+            listener.onResponse(mock(ClearScrollResponse.class));
             return null;
         }).when(client).execute(eq(ClearScrollAction.INSTANCE), any(), any());
     }
@@ -173,7 +171,6 @@ public class BatchedDocumentsIteratorTests extends ESTestCase {
         protected SearchResponse createSearchResponseWithHits(String... hits) {
             SearchHits searchHits = createHits(hits);
             SearchResponse searchResponse = mock(SearchResponse.class);
-            when(searchResponse.hasReferences()).thenReturn(true);
             when(searchResponse.getScrollId()).thenReturn(SCROLL_ID);
             when(searchResponse.getHits()).thenReturn(searchHits);
             return searchResponse;

+ 1 - 1
x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/transport/SecurityServerTransportInterceptor.java

@@ -543,7 +543,7 @@ public class SecurityServerTransportInterceptor implements TransportInterceptor
 
         AbstractRunnable getReceiveRunnable(T request, TransportChannel channel, Task task) {
             final Runnable releaseRequest = new RunOnce(request::decRef);
-            request.incRef();
+            request.mustIncRef();
             return new AbstractRunnable() {
                 @Override
                 public boolean isForceExecution() {

+ 0 - 1
x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherServiceTests.java

@@ -163,7 +163,6 @@ public class WatcherServiceTests extends ESTestCase {
 
         // response setup, successful refresh response
         RefreshResponse refreshResponse = mock(RefreshResponse.class);
-        when(refreshResponse.hasReferences()).thenReturn(true);
         when(refreshResponse.getSuccessfulShards()).thenReturn(
             clusterState.getMetadata().getIndices().get(Watch.INDEX).getNumberOfShards()
         );

+ 0 - 2
x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/execution/TriggeredWatchStoreTests.java

@@ -210,7 +210,6 @@ public class TriggeredWatchStoreTests extends ESTestCase {
         SearchResponse searchResponse1 = mock(SearchResponse.class);
         when(searchResponse1.getSuccessfulShards()).thenReturn(1);
         when(searchResponse1.getTotalShards()).thenReturn(1);
-        when(searchResponse1.hasReferences()).thenReturn(true);
         BytesArray source = new BytesArray("{}");
         SearchHit hit = new SearchHit(0, "first_foo");
         hit.version(1L);
@@ -513,7 +512,6 @@ public class TriggeredWatchStoreTests extends ESTestCase {
         RefreshResponse refreshResponse = mock(RefreshResponse.class);
         when(refreshResponse.getTotalShards()).thenReturn(total);
         when(refreshResponse.getSuccessfulShards()).thenReturn(successful);
-        when(refreshResponse.hasReferences()).thenReturn(true);
         return refreshResponse;
     }
 }