瀏覽代碼

Write async response directly to XContent to reduce memory usage (#73707)

This change tries to write an async response directly to XContent in 
Base64 to avoid using multiple buffers.

Relates to #67594
Nhat Nguyen 4 年之前
父節點
當前提交
6b7fea0b42

+ 33 - 26
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/async/AsyncTaskIndexService.java

@@ -21,17 +21,17 @@ import org.elasticsearch.client.OriginSettingClient;
 import org.elasticsearch.cluster.metadata.IndexMetadata;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.TriFunction;
-import org.elasticsearch.common.bytes.BytesReference;
 import org.elasticsearch.common.collect.Tuple;
 import org.elasticsearch.common.io.stream.ByteBufferStreamInput;
-import org.elasticsearch.common.io.stream.BytesStreamOutput;
 import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
 import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
+import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.Writeable;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.util.concurrent.ThreadContext;
 import org.elasticsearch.common.xcontent.XContentBuilder;
+import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.common.xcontent.XContentType;
 import org.elasticsearch.indices.SystemIndexDescriptor;
 import org.elasticsearch.search.fetch.subphase.FetchSourceContext;
@@ -45,11 +45,11 @@ import org.elasticsearch.xpack.core.security.authc.Authentication;
 import org.elasticsearch.xpack.core.security.authc.support.AuthenticationContextSerializer;
 
 import java.io.IOException;
+import java.io.OutputStream;
 import java.io.UncheckedIOException;
 import java.nio.ByteBuffer;
 import java.util.Base64;
 import java.util.Collections;
-import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
@@ -181,15 +181,23 @@ public final class AsyncTaskIndexService<R extends AsyncResponse<R>> {
                                Map<String, String> headers,
                                R response,
                                ActionListener<IndexResponse> listener) throws IOException {
-        Map<String, Object> source = new HashMap<>();
-        source.put(HEADERS_FIELD, headers);
-        source.put(EXPIRATION_TIME_FIELD, response.getExpirationTime());
-        source.put(RESULT_FIELD, encodeResponse(response));
-        IndexRequest indexRequest = new IndexRequest(index)
-            .create(true)
-            .id(docId)
-            .source(source, XContentType.JSON);
-        clientWithOrigin.index(indexRequest, listener);
+        try {
+            // TODO: Integrate with circuit breaker
+            final XContentBuilder source = XContentFactory.jsonBuilder()
+                .startObject()
+                .field(HEADERS_FIELD, headers)
+                .field(EXPIRATION_TIME_FIELD, response.getExpirationTime())
+                .directFieldAsBase64(RESULT_FIELD, os -> writeResponse(response, os))
+                .endObject();
+
+            final IndexRequest indexRequest = new IndexRequest(index)
+                .create(true)
+                .id(docId)
+                .source(source);
+            clientWithOrigin.index(indexRequest, listener);
+        } catch (Exception e) {
+            listener.onFailure(e);
+        }
     }
 
     /**
@@ -200,16 +208,19 @@ public final class AsyncTaskIndexService<R extends AsyncResponse<R>> {
                             R response,
                             ActionListener<UpdateResponse> listener) {
         try {
-            Map<String, Object> source = new HashMap<>();
-            source.put(RESPONSE_HEADERS_FIELD, responseHeaders);
-            source.put(RESULT_FIELD, encodeResponse(response));
+            // TODO: Integrate with circuit breaker
+            final XContentBuilder source = XContentFactory.jsonBuilder()
+                .startObject()
+                .field(RESPONSE_HEADERS_FIELD, responseHeaders)
+                .directFieldAsBase64(RESULT_FIELD, os -> writeResponse(response, os))
+                .endObject();
             UpdateRequest request = new UpdateRequest()
                 .index(index)
                 .id(docId)
-                .doc(source, XContentType.JSON)
+                .doc(source)
                 .retryOnConflict(5);
             clientWithOrigin.update(request, listener);
-        } catch(Exception e) {
+        } catch (Exception e) {
             listener.onFailure(e);
         }
     }
@@ -452,21 +463,17 @@ public final class AsyncTaskIndexService<R extends AsyncResponse<R>> {
         return origin.canAccessResourcesOf(current);
     }
 
-    /**
-     * Encode the provided response in a binary form using base64 encoding.
-     */
-    String encodeResponse(R response) throws IOException {
-        try (BytesStreamOutput out = new BytesStreamOutput()) {
-            Version.writeVersion(Version.CURRENT, out);
-            response.writeTo(out);
-            return Base64.getEncoder().encodeToString(BytesReference.toBytes(out.bytes()));
-        }
+    private void writeResponse(R response, OutputStream os) throws IOException {
+        final OutputStreamStreamOutput out = new OutputStreamStreamOutput(os);
+        Version.writeVersion(Version.CURRENT, out);
+        response.writeTo(out);
     }
 
     /**
      * Decode the provided base-64 bytes into a {@link AsyncSearchResponse}.
      */
     R decodeResponse(String value) throws IOException {
+        // TODO: Integrate with the circuit breaker
         try (ByteBufferStreamInput buf = new ByteBufferStreamInput(ByteBuffer.wrap(Base64.getDecoder().decode(value)))) {
             try (StreamInput in = new NamedWriteableAwareStreamInput(buf, registry)) {
                 in.setVersion(Version.readVersion(in));

+ 54 - 5
x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/async/AsyncSearchIndexServiceTests.java

@@ -6,14 +6,20 @@
  */
 package org.elasticsearch.xpack.core.async;
 
+import org.elasticsearch.action.DocWriteResponse;
+import org.elasticsearch.action.index.IndexResponse;
+import org.elasticsearch.action.support.PlainActionFuture;
+import org.elasticsearch.action.update.UpdateResponse;
 import org.elasticsearch.cluster.service.ClusterService;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
+import org.elasticsearch.tasks.TaskId;
 import org.elasticsearch.test.ESSingleNodeTestCase;
 import org.elasticsearch.transport.TransportService;
 import org.junit.Before;
 
 import java.io.IOException;
+import java.util.Map;
 import java.util.Objects;
 
 import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN;
@@ -66,6 +72,14 @@ public class AsyncSearchIndexServiceTests extends ESSingleNodeTestCase {
         public int hashCode() {
             return Objects.hash(test, expirationTimeMillis);
         }
+
+        @Override
+        public String toString() {
+            return "TestAsyncResponse{" +
+                "test='" + test + '\'' +
+                ", expirationTimeMillis=" + expirationTimeMillis +
+                '}';
+        }
     }
 
     @Before
@@ -77,11 +91,46 @@ public class AsyncSearchIndexServiceTests extends ESSingleNodeTestCase {
     }
 
     public void testEncodeSearchResponse() throws IOException {
-        for (int i = 0; i < 10; i++) {
-            TestAsyncResponse response = new TestAsyncResponse(randomAlphaOfLength(10), randomLong());
-            String encoded = indexService.encodeResponse(response);
-            TestAsyncResponse same = indexService.decodeResponse(encoded);
-            assertThat(same, equalTo(response));
+        final int iterations = iterations(1, 20);
+        for (int i = 0; i < iterations; i++) {
+            long expirationTime = randomLong();
+            String testMessage = randomAlphaOfLength(10);
+            TestAsyncResponse initialResponse = new TestAsyncResponse(testMessage, expirationTime);
+            AsyncExecutionId executionId = new AsyncExecutionId(
+                Long.toString(randomNonNegativeLong()),
+                new TaskId(randomAlphaOfLength(10), randomNonNegativeLong()));
+
+            PlainActionFuture<IndexResponse> createFuture = new PlainActionFuture<>();
+            indexService.createResponse(executionId.getDocId(), Map.of(), initialResponse, createFuture);
+            assertThat(createFuture.actionGet().getResult(), equalTo(DocWriteResponse.Result.CREATED));
+
+            if (randomBoolean()) {
+                PlainActionFuture<TestAsyncResponse> getFuture = new PlainActionFuture<>();
+                indexService.getResponse(executionId, randomBoolean(), getFuture);
+                assertThat(getFuture.actionGet(), equalTo(initialResponse));
+            }
+
+            int updates = randomIntBetween(1, 5);
+            for (int u = 0; u < updates; u++) {
+                if (randomBoolean()) {
+                    testMessage = randomAlphaOfLength(10);
+                    TestAsyncResponse updateResponse = new TestAsyncResponse(testMessage, randomLong());
+                    PlainActionFuture<UpdateResponse> updateFuture = new PlainActionFuture<>();
+                    indexService.updateResponse(executionId.getDocId(), Map.of(), updateResponse, updateFuture);
+                    updateFuture.actionGet();
+                } else {
+                    expirationTime = randomLong();
+                    PlainActionFuture<UpdateResponse> updateFuture = new PlainActionFuture<>();
+                    indexService.updateExpirationTime(executionId.getDocId(), expirationTime, updateFuture);
+                    updateFuture.actionGet();
+                }
+                if (randomBoolean()) {
+                    PlainActionFuture<TestAsyncResponse> getFuture = new PlainActionFuture<>();
+                    indexService.getResponse(executionId, randomBoolean(), getFuture);
+                    assertThat(getFuture.actionGet().test, equalTo(testMessage));
+                    assertThat(getFuture.actionGet().expirationTimeMillis, equalTo(expirationTime));
+                }
+            }
         }
     }
 }