فهرست منبع

Fix aggregation memory leak for CCS (#78404)

When a CCS search is proxied, the memory for the aggregations on the
proxy node would not be freed.

Now only use the non-copying byte referencing version on the coordinating node,
which itself ensures that memory is freed by calling `consumeAggs`.
Henning Andersen 4 سال پیش
والد
کامیت
80792a1a82

+ 146 - 0
server/src/internalClusterTest/java/org/elasticsearch/search/ccs/CrossClusterSearchLeakIT.java

@@ -0,0 +1,146 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the Elastic License
+ * 2.0 and the Server Side Public License, v 1; you may not use this file except
+ * in compliance with, at your election, the Elastic License 2.0 or the Server
+ * Side Public License, v 1.
+ */
+
+package org.elasticsearch.search.ccs;
+
+import org.elasticsearch.action.ActionFuture;
+import org.elasticsearch.action.search.ClearScrollRequest;
+import org.elasticsearch.action.search.SearchRequest;
+import org.elasticsearch.action.search.SearchResponse;
+import org.elasticsearch.client.Client;
+import org.elasticsearch.cluster.metadata.IndexMetadata;
+import org.elasticsearch.cluster.node.DiscoveryNode;
+import org.elasticsearch.common.settings.Settings;
+import org.elasticsearch.core.TimeValue;
+import org.elasticsearch.index.query.MatchAllQueryBuilder;
+import org.elasticsearch.search.aggregations.bucket.terms.Terms;
+import org.elasticsearch.search.builder.SearchSourceBuilder;
+import org.elasticsearch.test.AbstractMultiClustersTestCase;
+import org.elasticsearch.test.InternalTestCluster;
+import org.elasticsearch.transport.TransportService;
+import org.hamcrest.Matchers;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.StreamSupport;
+
+import static org.elasticsearch.search.aggregations.AggregationBuilders.terms;
+import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked;
+import static org.hamcrest.Matchers.equalTo;
+
+public class CrossClusterSearchLeakIT extends AbstractMultiClustersTestCase {
+
+    @Override
+    protected Collection<String> remoteClusterAlias() {
+        return List.of("cluster_a");
+    }
+
+    @Override
+    protected boolean reuseClusters() {
+        return false;
+    }
+
+    private int indexDocs(Client client, String field, String index) {
+        int numDocs = between(1, 200);
+        for (int i = 0; i < numDocs; i++) {
+            client.prepareIndex(index).setSource(field, "v" + i).get();
+        }
+        client.admin().indices().prepareRefresh(index).get();
+        return numDocs;
+    }
+
+    /**
+     * This test validates that we do not leak any memory when running CCS in various modes, actual validation is done by test framework
+     * (leak detection)
+     * <ul>
+     *     <li>proxy vs non-proxy</li>
+     *     <li>single-phase query-fetch or multi-phase</li>
+     *     <li>minimize roundtrip vs not</li>
+     *     <li>scroll vs no scroll</li>
+     * </ul>
+     */
+    public void testSearch() throws Exception {
+        assertAcked(client(LOCAL_CLUSTER).admin().indices().prepareCreate("demo")
+            .setMapping("f", "type=keyword")
+            .setSettings(Settings.builder().put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
+        indexDocs(client(LOCAL_CLUSTER), "ignored", "demo");
+        final InternalTestCluster remoteCluster = cluster("cluster_a");
+        int minRemotes = between(2, 5);
+        remoteCluster.ensureAtLeastNumDataNodes(minRemotes);
+        List<String> remoteDataNodes = StreamSupport.stream(remoteCluster.clusterService().state().nodes().spliterator(), false)
+            .filter(DiscoveryNode::canContainData)
+            .map(DiscoveryNode::getName)
+            .collect(Collectors.toList());
+        assertThat(remoteDataNodes.size(), Matchers.greaterThanOrEqualTo(minRemotes));
+        List<String> seedNodes = randomSubsetOf(between(1, remoteDataNodes.size() - 1), remoteDataNodes);
+        disconnectFromRemoteClusters();
+        configureRemoteCluster("cluster_a", seedNodes);
+        final Settings.Builder allocationFilter = Settings.builder();
+        if (rarely()) {
+            allocationFilter.put("index.routing.allocation.include._name", String.join(",", seedNodes));
+        } else {
+            // Provoke using proxy connections
+            allocationFilter.put("index.routing.allocation.exclude._name", String.join(",", seedNodes));
+        }
+        assertAcked(client("cluster_a").admin().indices().prepareCreate("prod")
+            .setMapping("f", "type=keyword")
+            .setSettings(Settings.builder().put(allocationFilter.build())
+                .put(IndexMetadata.SETTING_NUMBER_OF_REPLICAS, 0).put(IndexMetadata.SETTING_NUMBER_OF_SHARDS, between(1, 3))));
+        assertFalse(client("cluster_a").admin().cluster().prepareHealth("prod")
+            .setWaitForYellowStatus().setTimeout(TimeValue.timeValueSeconds(10)).get().isTimedOut());
+        int docs = indexDocs(client("cluster_a"), "f", "prod");
+
+        List<ActionFuture<SearchResponse>> futures = new ArrayList<>();
+        for (int i = 0; i < 10; ++i) {
+            String[] indices = randomBoolean() ? new String[] { "demo", "cluster_a:prod" } : new String[] { "cluster_a:prod" };
+            final SearchRequest searchRequest = new SearchRequest(indices);
+            searchRequest.allowPartialSearchResults(false);
+            boolean scroll = randomBoolean();
+            searchRequest.source(new SearchSourceBuilder().query(new MatchAllQueryBuilder())
+                .aggregation(terms("f").field("f").size(docs + between(scroll ? 1 : 0, 10))).size(between(0, 1000)));
+            if (scroll) {
+                searchRequest.scroll("30s");
+            }
+            searchRequest.setCcsMinimizeRoundtrips(rarely());
+            futures.add(client(LOCAL_CLUSTER).search(searchRequest));
+        }
+
+        for (ActionFuture<SearchResponse> future : futures) {
+            SearchResponse searchResponse = future.get();
+            if (searchResponse.getScrollId() != null) {
+                ClearScrollRequest clearScrollRequest = new ClearScrollRequest();
+                clearScrollRequest.scrollIds(List.of(searchResponse.getScrollId()));
+                client(LOCAL_CLUSTER).clearScroll(clearScrollRequest).get();
+            }
+
+            Terms terms = searchResponse.getAggregations().get("f");
+            assertThat(terms.getBuckets().size(), equalTo(docs));
+            for (Terms.Bucket bucket : terms.getBuckets()) {
+                assertThat(bucket.getDocCount(), equalTo(1L));
+            }
+        }
+    }
+
+    @Override
+    protected void configureRemoteCluster(String clusterAlias, Collection<String> seedNodes) throws Exception {
+        if (rarely()) {
+            super.configureRemoteCluster(clusterAlias, seedNodes);
+        } else {
+            final Settings.Builder settings = Settings.builder();
+            final String seedNode = randomFrom(seedNodes);
+            final TransportService transportService = cluster(clusterAlias).getInstance(TransportService.class, seedNode);
+            final String seedAddress = transportService.boundAddress().publishAddress().toString();
+
+            settings.put("cluster.remote." + clusterAlias + ".mode", "proxy");
+            settings.put("cluster.remote." + clusterAlias + ".proxy_address", seedAddress);
+            client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();
+        }
+    }
+}

+ 1 - 1
server/src/main/java/org/elasticsearch/action/search/SearchTransportService.java

@@ -138,7 +138,7 @@ public class SearchTransportService {
         // we optimize this and expect a QueryFetchSearchResult if we only have a single shard in the search request
         // this used to be the QUERY_AND_FETCH which doesn't exist anymore.
         final boolean fetchDocuments = request.numberOfShards() == 1;
-        Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : QuerySearchResult::new;
+        Writeable.Reader<SearchPhaseResult> reader = fetchDocuments ? QueryFetchSearchResult::new : in -> new QuerySearchResult(in, true);
 
         final ActionListener<? super SearchPhaseResult> handler = responseWrapper.apply(connection, listener);
         transportService.sendChildRequest(connection, QUERY_ACTION_NAME, request, task,

+ 20 - 8
server/src/main/java/org/elasticsearch/common/io/stream/DelayableWriteable.java

@@ -9,6 +9,7 @@
 package org.elasticsearch.common.io.stream;
 
 import org.elasticsearch.Version;
+import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.bytes.ReleasableBytesReference;
 import org.elasticsearch.core.Releasable;
 
@@ -50,6 +51,12 @@ public abstract class DelayableWriteable<T extends Writeable> implements Writeab
         return new Serialized<>(reader, in.getVersion(), in.namedWriteableRegistry(), in.readReleasableBytesReference());
     }
 
+    public static <T extends Writeable> DelayableWriteable<T> referencing(Writeable.Reader<T> reader, StreamInput in) throws IOException {
+        try (ReleasableBytesReference serialized = in.readReleasableBytesReference()) {
+            return new Referencing<>(deserialize(reader, in.getVersion(), in.namedWriteableRegistry(), serialized));
+        }
+    }
+
     private DelayableWriteable() {}
 
     /**
@@ -67,7 +74,7 @@ public abstract class DelayableWriteable<T extends Writeable> implements Writeab
      * {@code true} if the {@linkplain Writeable} is being stored in
      * serialized form, {@code false} otherwise.
      */
-    abstract boolean isSerialized();
+    public abstract boolean isSerialized();
 
     /**
      * Returns the serialized size of the inner {@link Writeable}.
@@ -104,7 +111,7 @@ public abstract class DelayableWriteable<T extends Writeable> implements Writeab
         }
 
         @Override
-        boolean isSerialized() {
+        public boolean isSerialized() {
             return false;
         }
 
@@ -169,11 +176,7 @@ public abstract class DelayableWriteable<T extends Writeable> implements Writeab
         @Override
         public T expand() {
             try {
-                try (StreamInput in = registry == null ?
-                        serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
-                    in.setVersion(serializedAtVersion);
-                    return reader.read(in);
-                }
+                return deserialize(reader, serializedAtVersion, registry, serialized);
             } catch (IOException e) {
                 throw new RuntimeException("unexpected error expanding serialized delayed writeable", e);
             }
@@ -185,7 +188,7 @@ public abstract class DelayableWriteable<T extends Writeable> implements Writeab
         }
 
         @Override
-        boolean isSerialized() {
+        public boolean isSerialized() {
             return true;
         }
 
@@ -214,6 +217,15 @@ public abstract class DelayableWriteable<T extends Writeable> implements Writeab
         }
     }
 
+    private static <T> T deserialize(Reader<T> reader, Version serializedAtVersion, NamedWriteableRegistry registry,
+                                     BytesReference serialized) throws IOException {
+        try (StreamInput in =
+                 registry == null ? serialized.streamInput() : new NamedWriteableAwareStreamInput(serialized.streamInput(), registry)) {
+            in.setVersion(serializedAtVersion);
+            return reader.read(in);
+        }
+    }
+
     private static class CountingStreamOutput extends StreamOutput {
         long size = 0;
 

+ 21 - 3
server/src/main/java/org/elasticsearch/search/query/QuerySearchResult.java

@@ -33,7 +33,6 @@ import static org.elasticsearch.common.lucene.Lucene.readTopDocs;
 import static org.elasticsearch.common.lucene.Lucene.writeTopDocs;
 
 public final class QuerySearchResult extends SearchPhaseResult {
-
     private int from;
     private int size;
     private TopDocsAndMaxScore topDocsAndMaxScore;
@@ -65,6 +64,15 @@ public final class QuerySearchResult extends SearchPhaseResult {
     }
 
     public QuerySearchResult(StreamInput in) throws IOException {
+        this(in, false);
+    }
+
+    /**
+     * Read the object, but using a delayed aggregations field when delayedAggregations=true. Using this, the caller must ensure that
+     * either `consumeAggs` or `releaseAggs` is called if `hasAggs() == true`.
+     * @param delayedAggregations whether to use delayed aggregations or not
+     */
+    public QuerySearchResult(StreamInput in, boolean delayedAggregations) throws IOException {
         super(in);
         if (in.getVersion().onOrAfter(Version.V_7_7_0)) {
             isNull = in.readBoolean();
@@ -73,7 +81,7 @@ public final class QuerySearchResult extends SearchPhaseResult {
         }
         if (isNull == false) {
             ShardSearchContextId id = new ShardSearchContextId(in);
-            readFromWithId(id, in);
+            readFromWithId(id, in, delayedAggregations);
         }
     }
 
@@ -316,6 +324,10 @@ public final class QuerySearchResult extends SearchPhaseResult {
     }
 
     public void readFromWithId(ShardSearchContextId id, StreamInput in) throws IOException {
+        readFromWithId(id, in, false);
+    }
+
+    private void readFromWithId(ShardSearchContextId id, StreamInput in, boolean delayedAggregations) throws IOException {
         this.contextId = id;
         from = in.readVInt();
         size = in.readVInt();
@@ -333,7 +345,11 @@ public final class QuerySearchResult extends SearchPhaseResult {
         boolean success = false;
         try {
             if (hasAggs) {
-                aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
+                if (delayedAggregations) {
+                    aggregations = DelayableWriteable.delayed(InternalAggregations::readFrom, in);
+                } else {
+                    aggregations = DelayableWriteable.referencing(InternalAggregations::readFrom, in);
+                }
             }
             if (in.readBoolean()) {
                 suggest = new Suggest(in);
@@ -359,6 +375,8 @@ public final class QuerySearchResult extends SearchPhaseResult {
 
     @Override
     public void writeTo(StreamOutput out) throws IOException {
+        // we do not know that it is being sent over transport, but this at least protects all writes from happening, including sending.
+        assert aggregations == null || aggregations.isSerialized() == false : "cannot send serialized version since it will leak";
         if (out.getVersion().onOrAfter(Version.V_7_7_0)) {
             out.writeBoolean(isNull);
         }

+ 8 - 4
server/src/test/java/org/elasticsearch/common/io/stream/DelayableWriteableTests.java

@@ -134,14 +134,12 @@ public class DelayableWriteableTests extends ESTestCase {
     public void testRoundTripFromDelayedFromOldVersion() throws IOException {
         Example e = new Example(randomAlphaOfLength(5));
         DelayableWriteable<Example> original = roundTrip(DelayableWriteable.referencing(e), Example::new, randomOldVersion());
-        assertTrue(original.isSerialized());
         roundTripTestCase(original, Example::new);
     }
 
     public void testRoundTripFromDelayedFromOldVersionWithNamedWriteable() throws IOException {
         NamedHolder n = new NamedHolder(new Example(randomAlphaOfLength(5)));
         DelayableWriteable<NamedHolder> original = roundTrip(DelayableWriteable.referencing(n), NamedHolder::new, randomOldVersion());
-        assertTrue(original.isSerialized());
         roundTripTestCase(original, NamedHolder::new);
     }
 
@@ -160,14 +158,20 @@ public class DelayableWriteableTests extends ESTestCase {
 
     private <T extends Writeable> void roundTripTestCase(DelayableWriteable<T> original, Writeable.Reader<T> reader) throws IOException {
         DelayableWriteable<T> roundTripped = roundTrip(original, reader, Version.CURRENT);
-        assertTrue(roundTripped.isSerialized());
         assertThat(roundTripped.expand(), equalTo(original.expand()));
     }
 
     private <T extends Writeable> DelayableWriteable<T> roundTrip(DelayableWriteable<T> original,
             Writeable.Reader<T> reader, Version version) throws IOException {
-        return copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
+        DelayableWriteable<T> delayed = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
             in -> DelayableWriteable.delayed(reader, in), version);
+        assertTrue(delayed.isSerialized());
+
+        DelayableWriteable<T> referencing = copyInstance(original, writableRegistry(), (out, d) -> d.writeTo(out),
+            in -> DelayableWriteable.referencing(reader, in), version);
+        assertFalse(referencing.isSerialized());
+
+        return randomFrom(delayed, referencing);
     }
 
     @Override

+ 7 - 1
server/src/test/java/org/elasticsearch/search/query/QuerySearchResultTests.java

@@ -33,6 +33,8 @@ import org.elasticsearch.search.suggest.SuggestTests;
 import org.elasticsearch.test.ESTestCase;
 
 import static java.util.Collections.emptyList;
+import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
 
 public class QuerySearchResultTests extends ESTestCase {
 
@@ -68,8 +70,10 @@ public class QuerySearchResultTests extends ESTestCase {
 
     public void testSerialization() throws Exception {
         QuerySearchResult querySearchResult = createTestInstance();
+        boolean delayed = randomBoolean();
         QuerySearchResult deserialized = copyWriteable(querySearchResult, namedWriteableRegistry,
-            QuerySearchResult::new, Version.CURRENT);
+            delayed ? in -> new QuerySearchResult(in, true) : QuerySearchResult::new,
+            Version.CURRENT);
         assertEquals(querySearchResult.getContextId().getId(), deserialized.getContextId().getId());
         assertNull(deserialized.getSearchShardTarget());
         assertEquals(querySearchResult.topDocs().maxScore, deserialized.topDocs().maxScore, 0f);
@@ -78,9 +82,11 @@ public class QuerySearchResultTests extends ESTestCase {
         assertEquals(querySearchResult.size(), deserialized.size());
         assertEquals(querySearchResult.hasAggs(), deserialized.hasAggs());
         if (deserialized.hasAggs()) {
+            assertThat(deserialized.aggregations().isSerialized(), is(delayed));
             Aggregations aggs = querySearchResult.consumeAggs();
             Aggregations deserializedAggs = deserialized.consumeAggs();
             assertEquals(aggs.asList(), deserializedAggs.asList());
+            assertThat(deserialized.aggregations(), is(nullValue()));
         }
         assertEquals(querySearchResult.terminatedEarly(), deserialized.terminatedEarly());
     }

+ 2 - 0
test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java

@@ -134,6 +134,8 @@ public abstract class AbstractMultiClustersTestCase extends ESTestCase {
         for (String clusterAlias : clusterAliases) {
             if (clusterAlias.equals(LOCAL_CLUSTER) == false) {
                 settings.putNull("cluster.remote." + clusterAlias + ".seeds");
+                settings.putNull("cluster.remote." + clusterAlias + ".mode");
+                settings.putNull("cluster.remote." + clusterAlias + ".proxy_address");
             }
         }
         client().admin().cluster().prepareUpdateSettings().setPersistentSettings(settings).get();