Browse Source

Faster ref-count logic for when ref-counted object does not escape (#105338)

Introducing a plain version of `AbstractRefCounted` as a compromise.
This saves a bunch of allocations and a circular reference to the object
holding the ref counted instance, making smaller SearchHit instances
etc. cheaper. We could get an even more direct solution here by making
these extend `AbstractRefCounted` but that would lose us the ability to
leak-track in tests, so doing it this way (same way Netty does it on
their end) as a compromise.
Armin Braun 1 year ago
parent
commit
842915701d

+ 18 - 0
libs/core/src/main/java/org/elasticsearch/core/SimpleRefCounted.java

@@ -0,0 +1,18 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.core;
+
+/**
+ * {@link RefCounted} which does nothing when all references are released. It is the responsibility of the caller
+ * to run whatever release logic should be executed when {@link AbstractRefCounted#decRef()} returns true.
+ */
+public class SimpleRefCounted extends AbstractRefCounted {
+    @Override
+    protected void closeInternal() {}
+}

+ 18 - 15
server/src/main/java/org/elasticsearch/action/search/MultiSearchResponse.java

@@ -19,9 +19,9 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.transport.LeakTracker;
 import org.elasticsearch.xcontent.ToXContent;
@@ -122,19 +122,7 @@ public class MultiSearchResponse extends ActionResponse implements Iterable<Mult
     private final Item[] items;
     private final long tookInMillis;
 
-    private final RefCounted refCounted = LeakTracker.wrap(new AbstractRefCounted() {
-        @Override
-        protected void closeInternal() {
-            for (int i = 0; i < items.length; i++) {
-                Item item = items[i];
-                var r = item.response;
-                if (r != null) {
-                    r.decRef();
-                    items[i] = null;
-                }
-            }
-        }
-    });
+    private final RefCounted refCounted = LeakTracker.wrap(new SimpleRefCounted());
 
     public MultiSearchResponse(StreamInput in) throws IOException {
         super(in);
@@ -163,7 +151,22 @@ public class MultiSearchResponse extends ActionResponse implements Iterable<Mult
 
     @Override
     public boolean decRef() {
-        return refCounted.decRef();
+        if (refCounted.decRef()) {
+            deallocate();
+            return true;
+        }
+        return false;
+    }
+
+    private void deallocate() {
+        for (int i = 0; i < items.length; i++) {
+            Item item = items[i];
+            var r = item.response;
+            if (r != null) {
+                r.decRef();
+                items[i] = null;
+            }
+        }
     }
 
     @Override

+ 7 - 8
server/src/main/java/org/elasticsearch/action/search/SearchResponse.java

@@ -20,9 +20,9 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.util.concurrent.ConcurrentCollections;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
 import org.elasticsearch.common.xcontent.ChunkedToXContentObject;
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.rest.RestStatus;
 import org.elasticsearch.rest.action.RestActions;
@@ -83,12 +83,7 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
     private final Clusters clusters;
     private final long tookInMillis;
 
-    private final RefCounted refCounted = LeakTracker.wrap(new AbstractRefCounted() {
-        @Override
-        protected void closeInternal() {
-            hits.decRef();
-        }
-    });
+    private final RefCounted refCounted = LeakTracker.wrap(new SimpleRefCounted());
 
     public SearchResponse(StreamInput in) throws IOException {
         super(in);
@@ -232,7 +227,11 @@ public class SearchResponse extends ActionResponse implements ChunkedToXContentO
 
     @Override
     public boolean decRef() {
-        return refCounted.decRef();
+        if (refCounted.decRef()) {
+            hits.decRef();
+            return true;
+        }
+        return false;
     }
 
     @Override

+ 7 - 8
server/src/main/java/org/elasticsearch/action/search/SearchResponseSections.java

@@ -8,8 +8,8 @@
 
 package org.elasticsearch.action.search;
 
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.aggregations.InternalAggregations;
 import org.elasticsearch.search.profile.SearchProfileResults;
@@ -71,12 +71,7 @@ public class SearchResponseSections implements RefCounted {
         this.timedOut = timedOut;
         this.terminatedEarly = terminatedEarly;
         this.numReducePhases = numReducePhases;
-        refCounted = hits.getHits().length > 0 ? LeakTracker.wrap(new AbstractRefCounted() {
-            @Override
-            protected void closeInternal() {
-                hits.decRef();
-            }
-        }) : ALWAYS_REFERENCED;
+        refCounted = hits.getHits().length > 0 ? LeakTracker.wrap(new SimpleRefCounted()) : ALWAYS_REFERENCED;
     }
 
     public final SearchHits hits() {
@@ -112,7 +107,11 @@ public class SearchResponseSections implements RefCounted {
 
     @Override
     public boolean decRef() {
-        return refCounted.decRef();
+        if (refCounted.decRef()) {
+            hits.decRef();
+            return true;
+        }
+        return false;
     }
 
     @Override

+ 20 - 17
server/src/main/java/org/elasticsearch/search/SearchHit.java

@@ -25,10 +25,10 @@ import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.RestApiVersion;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.index.mapper.IgnoredFieldMapper;
 import org.elasticsearch.index.mapper.MapperService;
 import org.elasticsearch.index.mapper.SourceFieldMapper;
@@ -204,21 +204,7 @@ public final class SearchHit implements Writeable, ToXContentObject, RefCounted
         this.innerHits = innerHits;
         this.documentFields = documentFields;
         this.metaFields = metaFields;
-        this.refCounted = refCounted == null ? LeakTracker.wrap(new AbstractRefCounted() {
-            @Override
-            protected void closeInternal() {
-                if (SearchHit.this.innerHits != null) {
-                    for (SearchHits h : SearchHit.this.innerHits.values()) {
-                        h.decRef();
-                    }
-                    SearchHit.this.innerHits = null;
-                }
-                if (SearchHit.this.source instanceof RefCounted r) {
-                    r.decRef();
-                }
-                SearchHit.this.source = null;
-            }
-        }) : ALWAYS_REFERENCED;
+        this.refCounted = refCounted == null ? LeakTracker.wrap(new SimpleRefCounted()) : ALWAYS_REFERENCED;
     }
 
     public static SearchHit readFrom(StreamInput in, boolean pooled) throws IOException {
@@ -726,7 +712,24 @@ public final class SearchHit implements Writeable, ToXContentObject, RefCounted
 
     @Override
     public boolean decRef() {
-        return refCounted.decRef();
+        if (refCounted.decRef()) {
+            deallocate();
+            return true;
+        }
+        return false;
+    }
+
+    private void deallocate() {
+        if (SearchHit.this.innerHits != null) {
+            for (SearchHits h : SearchHit.this.innerHits.values()) {
+                h.decRef();
+            }
+            SearchHit.this.innerHits = null;
+        }
+        if (SearchHit.this.source instanceof RefCounted r) {
+            r.decRef();
+        }
+        SearchHit.this.source = null;
     }
 
     @Override

+ 15 - 12
server/src/main/java/org/elasticsearch/search/SearchHits.java

@@ -18,9 +18,9 @@ import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.lucene.Lucene;
 import org.elasticsearch.common.xcontent.ChunkedToXContent;
 import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.rest.action.search.RestSearchAction;
 import org.elasticsearch.transport.LeakTracker;
 import org.elasticsearch.xcontent.ToXContent;
@@ -76,16 +76,7 @@ public final class SearchHits implements Writeable, ChunkedToXContent, RefCounte
             sortFields,
             collapseField,
             collapseValues,
-            hits.length == 0 ? ALWAYS_REFERENCED : LeakTracker.wrap(new AbstractRefCounted() {
-                @Override
-                protected void closeInternal() {
-                    for (int i = 0; i < hits.length; i++) {
-                        assert hits[i] != null;
-                        hits[i].decRef();
-                        hits[i] = null;
-                    }
-                }
-            })
+            hits.length == 0 ? ALWAYS_REFERENCED : LeakTracker.wrap(new SimpleRefCounted())
         );
     }
 
@@ -249,7 +240,19 @@ public final class SearchHits implements Writeable, ChunkedToXContent, RefCounte
 
     @Override
     public boolean decRef() {
-        return refCounted.decRef();
+        if (refCounted.decRef()) {
+            deallocate();
+            return true;
+        }
+        return false;
+    }
+
+    private void deallocate() {
+        for (int i = 0; i < hits.length; i++) {
+            assert hits[i] != null;
+            hits[i].decRef();
+            hits[i] = null;
+        }
     }
 
     @Override

+ 14 - 8
server/src/main/java/org/elasticsearch/search/fetch/FetchSearchResult.java

@@ -10,8 +10,8 @@ package org.elasticsearch.search.fetch;
 
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.SearchPhaseResult;
@@ -30,12 +30,7 @@ public final class FetchSearchResult extends SearchPhaseResult {
 
     private ProfileResult profileResult;
 
-    private final RefCounted refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> {
-        if (hits != null) {
-            hits.decRef();
-            hits = null;
-        }
-    }));
+    private final RefCounted refCounted = LeakTracker.wrap(new SimpleRefCounted());
 
     public FetchSearchResult() {}
 
@@ -109,7 +104,18 @@ public final class FetchSearchResult extends SearchPhaseResult {
 
     @Override
     public boolean decRef() {
-        return refCounted.decRef();
+        if (refCounted.decRef()) {
+            deallocate();
+            return true;
+        }
+        return false;
+    }
+
+    private void deallocate() {
+        if (hits != null) {
+            hits.decRef();
+            hits = null;
+        }
     }
 
     @Override

+ 12 - 6
server/src/main/java/org/elasticsearch/search/fetch/QueryFetchSearchResult.java

@@ -10,8 +10,8 @@ package org.elasticsearch.search.fetch;
 
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.RefCounted;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.search.SearchPhaseResult;
 import org.elasticsearch.search.SearchShardTarget;
 import org.elasticsearch.search.internal.ShardSearchContextId;
@@ -41,10 +41,7 @@ public final class QueryFetchSearchResult extends SearchPhaseResult {
     private QueryFetchSearchResult(QuerySearchResult queryResult, FetchSearchResult fetchResult) {
         this.queryResult = queryResult;
         this.fetchResult = fetchResult;
-        refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> {
-            queryResult.decRef();
-            fetchResult.decRef();
-        }));
+        refCounted = LeakTracker.wrap(new SimpleRefCounted());
     }
 
     @Override
@@ -99,7 +96,16 @@ public final class QueryFetchSearchResult extends SearchPhaseResult {
 
     @Override
     public boolean decRef() {
-        return refCounted.decRef();
+        if (refCounted.decRef()) {
+            deallocate();
+            return true;
+        }
+        return false;
+    }
+
+    private void deallocate() {
+        queryResult.decRef();
+        fetchResult.decRef();
     }
 
     @Override

+ 7 - 3
server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java

@@ -15,11 +15,11 @@ import org.elasticsearch.common.io.stream.DelayableWriteable;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
-import org.elasticsearch.core.AbstractRefCounted;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.RefCounted;
 import org.elasticsearch.core.Releasable;
 import org.elasticsearch.core.Releasables;
+import org.elasticsearch.core.SimpleRefCounted;
 import org.elasticsearch.search.DocValueFormat;
 import org.elasticsearch.search.RescoreDocIds;
 import org.elasticsearch.search.SearchPhaseResult;
@@ -103,7 +103,7 @@ public final class QuerySearchResult extends SearchPhaseResult {
         isNull = false;
         setShardSearchRequest(shardSearchRequest);
         this.toRelease = new ArrayList<>();
-        this.refCounted = LeakTracker.wrap(AbstractRefCounted.of(() -> Releasables.close(toRelease)));
+        this.refCounted = LeakTracker.wrap(new SimpleRefCounted());
     }
 
     private QuerySearchResult(boolean isNull) {
@@ -489,7 +489,11 @@ public final class QuerySearchResult extends SearchPhaseResult {
     @Override
     public boolean decRef() {
         if (refCounted != null) {
-            return refCounted.decRef();
+            if (refCounted.decRef()) {
+                Releasables.close(toRelease);
+                return true;
+            }
+            return false;
         }
         return super.decRef();
     }