瀏覽代碼

[ML] Ensuring only a single request executor object is created (#133424) (#133670)

* Ensuring only a single request executor thread is started

* Reverting test changes

* Update docs/changelog/133424.yaml
Jonathan Buttner 1 月之前
父節點
當前提交
0dc6305692

+ 5 - 0
docs/changelog/133424.yaml

@@ -0,0 +1,5 @@
+pr: 133424
+summary: Ensuring only a single request executor object is created
+area: Machine Learning
+type: bug
+issues: []

+ 29 - 27
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSender.java

@@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.support.ContextPreservingActionListener;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.Nullable;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
@@ -42,32 +41,39 @@ public class HttpRequestSender implements Sender {
      * A helper class for constructing a {@link HttpRequestSender}.
      */
     public static class Factory {
-        private final ServiceComponents serviceComponents;
-        private final HttpClientManager httpClientManager;
-        private final ClusterService clusterService;
-        private final RequestSender requestSender;
+        private final HttpRequestSender httpRequestSender;
 
         public Factory(ServiceComponents serviceComponents, HttpClientManager httpClientManager, ClusterService clusterService) {
-            this.serviceComponents = Objects.requireNonNull(serviceComponents);
-            this.httpClientManager = Objects.requireNonNull(httpClientManager);
-            this.clusterService = Objects.requireNonNull(clusterService);
+            Objects.requireNonNull(serviceComponents);
+            Objects.requireNonNull(clusterService);
+            Objects.requireNonNull(httpClientManager);
 
-            requestSender = new RetryingHttpSender(
-                this.httpClientManager.getHttpClient(),
+            var requestSender = new RetryingHttpSender(
+                httpClientManager.getHttpClient(),
                 serviceComponents.throttlerManager(),
                 new RetrySettings(serviceComponents.settings(), clusterService),
                 serviceComponents.threadPool()
             );
-        }
 
-        public Sender createSender() {
-            return new HttpRequestSender(
+            var startCompleted = new CountDownLatch(1);
+            var service = new RequestExecutorService(
                 serviceComponents.threadPool(),
-                httpClientManager,
-                clusterService,
-                serviceComponents.settings(),
+                startCompleted,
+                new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
                 requestSender
             );
+
+            httpRequestSender = new HttpRequestSender(
+                serviceComponents.threadPool(),
+                httpClientManager,
+                requestSender,
+                service,
+                startCompleted
+            );
+        }
+
+        public Sender createSender() {
+            return httpRequestSender;
         }
     }
 
@@ -75,27 +81,23 @@ public class HttpRequestSender implements Sender {
 
     private final ThreadPool threadPool;
     private final HttpClientManager manager;
-    private final RequestExecutor service;
     private final AtomicBoolean started = new AtomicBoolean(false);
-    private final CountDownLatch startCompleted = new CountDownLatch(1);
     private final RequestSender requestSender;
+    private final RequestExecutor service;
+    private final CountDownLatch startCompleted;
 
     private HttpRequestSender(
         ThreadPool threadPool,
         HttpClientManager httpClientManager,
-        ClusterService clusterService,
-        Settings settings,
-        RequestSender requestSender
+        RequestSender requestSender,
+        RequestExecutor service,
+        CountDownLatch startCompleted
     ) {
         this.threadPool = Objects.requireNonNull(threadPool);
         this.manager = Objects.requireNonNull(httpClientManager);
         this.requestSender = Objects.requireNonNull(requestSender);
-        service = new RequestExecutorService(
-            threadPool,
-            startCompleted,
-            new RequestExecutorServiceSettings(settings, clusterService),
-            requestSender
-        );
+        this.service = Objects.requireNonNull(service);
+        this.startCompleted = Objects.requireNonNull(startCompleted);
     }
 
     /**

+ 37 - 28
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSender.java

@@ -11,7 +11,6 @@ import org.apache.logging.log4j.Logger;
 import org.elasticsearch.ElasticsearchException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.threadpool.ThreadPool;
@@ -38,31 +37,46 @@ public class AmazonBedrockRequestSender implements Sender {
 
     public static class Factory {
         private final ServiceComponents serviceComponents;
-        private final ClusterService clusterService;
+        private final AmazonBedrockRequestExecutorService executorService;
+        private final CountDownLatch startCompleted = new CountDownLatch(1);
+        private final AmazonBedrockRequestSender bedrockRequestSender;
 
         public Factory(ServiceComponents serviceComponents, ClusterService clusterService) {
-            this.serviceComponents = Objects.requireNonNull(serviceComponents);
-            this.clusterService = Objects.requireNonNull(clusterService);
-        }
-
-        public Sender createSender() {
-            var clientCache = new AmazonBedrockInferenceClientCache(
-                (model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
-                Clock.systemUTC()
+            this(
+                serviceComponents,
+                clusterService,
+                new AmazonBedrockExecuteOnlyRequestSender(
+                    new AmazonBedrockInferenceClientCache(
+                        (model, timeout) -> AmazonBedrockInferenceClient.create(model, timeout, serviceComponents.threadPool()),
+                        Clock.systemUTC()
+                    ),
+                    serviceComponents.throttlerManager()
+                )
             );
-            return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager()));
         }
 
-        Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) {
-            var sender = new AmazonBedrockRequestSender(
+        public Factory(
+            ServiceComponents serviceComponents,
+            ClusterService clusterService,
+            AmazonBedrockExecuteOnlyRequestSender requestSender
+        ) {
+            this.serviceComponents = Objects.requireNonNull(serviceComponents);
+            Objects.requireNonNull(clusterService);
+
+            executorService = new AmazonBedrockRequestExecutorService(
                 serviceComponents.threadPool(),
-                clusterService,
-                serviceComponents.settings(),
-                Objects.requireNonNull(requestSender)
+                startCompleted,
+                new RequestExecutorServiceSettings(serviceComponents.settings(), clusterService),
+                requestSender
             );
+
+            bedrockRequestSender = new AmazonBedrockRequestSender(serviceComponents.threadPool(), executorService, startCompleted);
+        }
+
+        public Sender createSender() {
             // ensure this is started
-            sender.start();
-            return sender;
+            bedrockRequestSender.start();
+            return bedrockRequestSender;
         }
     }
 
@@ -71,21 +85,16 @@ public class AmazonBedrockRequestSender implements Sender {
     private final ThreadPool threadPool;
     private final AmazonBedrockRequestExecutorService executorService;
     private final AtomicBoolean started = new AtomicBoolean(false);
-    private final CountDownLatch startCompleted = new CountDownLatch(1);
+    private final CountDownLatch startCompleted;
 
     protected AmazonBedrockRequestSender(
         ThreadPool threadPool,
-        ClusterService clusterService,
-        Settings settings,
-        AmazonBedrockExecuteOnlyRequestSender requestSender
+        AmazonBedrockRequestExecutorService executorService,
+        CountDownLatch startCompleted
     ) {
         this.threadPool = Objects.requireNonNull(threadPool);
-        executorService = new AmazonBedrockRequestExecutorService(
-            threadPool,
-            startCompleted,
-            new RequestExecutorServiceSettings(settings, clusterService),
-            requestSender
-        );
+        this.executorService = Objects.requireNonNull(executorService);
+        this.startCompleted = Objects.requireNonNull(startCompleted);
     }
 
     @Override

+ 22 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java

@@ -58,6 +58,7 @@ import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
 import static org.mockito.ArgumentMatchers.anyString;
 import static org.mockito.Mockito.any;
 import static org.mockito.Mockito.doAnswer;
@@ -90,6 +91,27 @@ public class HttpRequestSenderTests extends ESTestCase {
         webServer.close();
     }
 
+    public void testCreateSender_ReturnsSameRequestExecutorInstance() {
+        var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
+
+        var sender1 = createSender(senderFactory);
+        var sender2 = createSender(senderFactory);
+
+        assertThat(sender1, instanceOf(HttpRequestSender.class));
+        assertThat(sender2, instanceOf(HttpRequestSender.class));
+        assertThat(sender1, sameInstance(sender2));
+    }
+
+    public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
+        var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+            sender.start();
+            sender.start();
+        }
+    }
+
     public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception {
         var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty());
 

+ 41 - 8
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/client/AmazonBedrockRequestSenderTests.java

@@ -37,7 +37,9 @@ import static org.elasticsearch.xpack.core.inference.results.TextEmbeddingFloatR
 import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool;
 import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty;
 import static org.elasticsearch.xpack.inference.services.amazonbedrock.client.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT;
+import static org.hamcrest.Matchers.instanceOf;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.sameInstance;
 import static org.mockito.Mockito.mock;
 
 public class AmazonBedrockRequestSenderTests extends ESTestCase {
@@ -60,11 +62,37 @@ public class AmazonBedrockRequestSenderTests extends ESTestCase {
         terminate(threadPool);
     }
 
+    public void testCreateSender_UsesTheSameInstanceForRequestExecutor() throws Exception {
+        var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
+        requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
+        var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
+
+        var sender1 = createSender(senderFactory);
+        var sender2 = createSender(senderFactory);
+
+        assertThat(sender1, instanceOf(AmazonBedrockRequestSender.class));
+        assertThat(sender2, instanceOf(AmazonBedrockRequestSender.class));
+
+        assertThat(sender1, sameInstance(sender2));
+    }
+
+    public void testCreateSender_CanCallStartMultipleTimes() throws Exception {
+        var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
+        requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
+        var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
+
+        try (var sender = createSender(senderFactory)) {
+            sender.start();
+            sender.start();
+            sender.start();
+        }
+    }
+
     public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception {
-        var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
         var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
         requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT));
-        try (var sender = createSender(senderFactory, requestSender)) {
+        var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             var model = AmazonBedrockEmbeddingsModelTests.createModel(
@@ -92,10 +120,10 @@ public class AmazonBedrockRequestSenderTests extends ESTestCase {
     }
 
     public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception {
-        var senderFactory = createSenderFactory(threadPool, Settings.EMPTY);
         var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class));
         requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text"));
-        try (var sender = createSender(senderFactory, requestSender)) {
+        var senderFactory = createSenderFactory(threadPool, Settings.EMPTY, requestSender);
+        try (var sender = createSender(senderFactory)) {
             sender.start();
 
             var model = AmazonBedrockChatCompletionModelTests.createModel(
@@ -116,14 +144,19 @@ public class AmazonBedrockRequestSenderTests extends ESTestCase {
         }
     }
 
-    public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) {
+    public static AmazonBedrockRequestSender.Factory createSenderFactory(
+        ThreadPool threadPool,
+        Settings settings,
+        AmazonBedrockMockExecuteRequestSender requestSender
+    ) {
         return new AmazonBedrockRequestSender.Factory(
             ServiceComponentsTests.createWithSettings(threadPool, settings),
-            mockClusterServiceEmpty()
+            mockClusterServiceEmpty(),
+            requestSender
         );
     }
 
-    public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) {
-        return factory.createSender(requestSender);
+    public static Sender createSender(AmazonBedrockRequestSender.Factory factory) {
+        return factory.createSender();
     }
 }