Browse Source

[ML] Return chunks for each input to InferenceService::chunkedInfer (#105447)

David Kyle 1 year ago
parent
commit
88f82b5c93
13 changed files with 211 additions and 46 deletions
  1. 2 2
      server/src/main/java/org/elasticsearch/inference/InferenceService.java
  2. 4 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java
  3. 4 0
      x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingResults.java
  4. 3 3
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java
  5. 2 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
  6. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java
  7. 16 18
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java
  8. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java
  9. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java
  10. 17 17
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/textembedding/TextEmbeddingInternalService.java
  11. 1 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java
  12. 81 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java
  13. 78 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/textembedding/TextEmbeddingInternalServiceTests.java

+ 2 - 2
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -103,7 +103,7 @@ public interface InferenceService extends Closeable {
      * @param taskSettings Settings in the request to override the model's defaults
      * @param inputType For search, ingest etc
      * @param chunkingOptions The window and span options to apply
-     * @param listener Inference result listener
+     * @param listener Chunked Inference result listener
      */
     void chunkedInfer(
         Model model,
@@ -111,7 +111,7 @@ public interface InferenceService extends Closeable {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     );
 
     /**

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedSparseEmbeddingResults.java

@@ -37,6 +37,10 @@ public class ChunkedSparseEmbeddingResults implements ChunkedInferenceServiceRes
         this.chunkedResults = in.readCollectionAsList(ChunkedTextExpansionResults.ChunkedResult::new);
     }
 
+    public List<ChunkedTextExpansionResults.ChunkedResult> getChunkedResults() {
+        return chunkedResults;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startArray("sparse_embedding_chunk");

+ 4 - 0
x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/ChunkedTextEmbeddingResults.java

@@ -42,6 +42,10 @@ public class ChunkedTextEmbeddingResults implements ChunkedInferenceServiceResul
         );
     }
 
+    public List<org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults.EmbeddingChunk> getChunks() {
+        return chunks;
+    }
+
     @Override
     public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
         builder.startArray("text_embedding_chunk");

+ 3 - 3
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestInferenceServiceExtension.java

@@ -152,7 +152,7 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
             Map<String, Object> taskSettings,
             InputType inputType,
             ChunkingOptions chunkingOptions,
-            ActionListener<ChunkedInferenceServiceResults> listener
+            ActionListener<List<ChunkedInferenceServiceResults>> listener
         ) {
             switch (model.getConfigurations().getTaskType()) {
                 case ANY, SPARSE_EMBEDDING -> listener.onResponse(makeChunkedResults(input));
@@ -177,7 +177,7 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
             return new SparseEmbeddingResults(embeddings);
         }
 
-        private ChunkedSparseEmbeddingResults makeChunkedResults(List<String> input) {
+        private List<ChunkedInferenceServiceResults> makeChunkedResults(List<String> input) {
             var chunks = new ArrayList<ChunkedTextExpansionResults.ChunkedResult>();
             for (int i = 0; i < input.size(); i++) {
                 var tokens = new ArrayList<TextExpansionResults.WeightedToken>();
@@ -186,7 +186,7 @@ public class TestInferenceServiceExtension implements InferenceServiceExtension
                 }
                 chunks.add(new ChunkedTextExpansionResults.ChunkedResult(input.get(i), tokens));
             }
-            return new ChunkedSparseEmbeddingResults(chunks);
+            return List.of(new ChunkedSparseEmbeddingResults(chunks));
         }
 
         @Override

+ 2 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

@@ -63,7 +63,7 @@ public abstract class SenderService implements InferenceService {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
         init();
         doChunkedInfer(model, input, taskSettings, inputType, chunkingOptions, listener);
@@ -83,7 +83,7 @@ public abstract class SenderService implements InferenceService {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     );
 
     @Override

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java

@@ -181,7 +181,7 @@ public class CohereService extends SenderService {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
         listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
     }

+ 16 - 18
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java

@@ -44,6 +44,7 @@ import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper;
 import org.elasticsearch.xpack.inference.services.ServiceUtils;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -285,7 +286,7 @@ public class ElserInternalService implements InferenceService {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
         try {
             checkCompatibleTaskType(model.getConfigurations().getTaskType());
@@ -355,25 +356,22 @@ public class ElserInternalService implements InferenceService {
         return ElserMlNodeTaskSettings.DEFAULT;
     }
 
-    private ChunkedSparseEmbeddingResults translateChunkedResults(List<InferenceResults> inferenceResults) {
-        if (inferenceResults.size() != 1) {
-            throw new ElasticsearchStatusException(
-                "Expected exactly one chunked sparse embedding result",
-                RestStatus.INTERNAL_SERVER_ERROR
-            );
-        }
+    private List<ChunkedInferenceServiceResults> translateChunkedResults(List<InferenceResults> inferenceResults) {
+        var translated = new ArrayList<ChunkedInferenceServiceResults>();
 
-        if (inferenceResults.get(0) instanceof ChunkedTextExpansionResults mlChunkedResult) {
-            return ChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult);
-        } else {
-            throw new ElasticsearchStatusException(
-                "Expected a chunked inference [{}] received [{}]",
-                RestStatus.INTERNAL_SERVER_ERROR,
-                ChunkedTextExpansionResults.NAME,
-                inferenceResults.get(0).getWriteableName()
-            );
+        for (var inferenceResult : inferenceResults) {
+            if (inferenceResult instanceof ChunkedTextExpansionResults mlChunkedResult) {
+                translated.add(ChunkedSparseEmbeddingResults.ofMlResult(mlChunkedResult));
+            } else {
+                throw new ElasticsearchStatusException(
+                    "Expected a chunked inference [{}] received [{}]",
+                    RestStatus.INTERNAL_SERVER_ERROR,
+                    ChunkedTextExpansionResults.NAME,
+                    inferenceResult.getWriteableName()
+                );
+            }
         }
-
+        return translated;
     }
 
     @Override

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseService.java

@@ -126,7 +126,7 @@ public abstract class HuggingFaceBaseService extends SenderService {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
         listener.onFailure(new UnsupportedOperationException("Chunked inference not implemented for Hugging Face"));
     }

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java

@@ -190,7 +190,7 @@ public class OpenAiService extends SenderService {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
         listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
     }

+ 17 - 17
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/textembedding/TextEmbeddingInternalService.java

@@ -36,13 +36,13 @@ import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.action.StopTrainedModelDeploymentAction;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig;
 import org.elasticsearch.xpack.core.ml.inference.TrainedModelInput;
-import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate;
 import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TokenizationConfigUpdate;
 import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;
 
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -241,7 +241,7 @@ public class TextEmbeddingInternalService implements InferenceService {
         Map<String, Object> taskSettings,
         InputType inputType,
         ChunkingOptions chunkingOptions,
-        ActionListener<ChunkedInferenceServiceResults> listener
+        ActionListener<List<ChunkedInferenceServiceResults>> listener
     ) {
         try {
             checkCompatibleTaskType(model.getConfigurations().getTaskType());
@@ -345,23 +345,23 @@ public class TextEmbeddingInternalService implements InferenceService {
         }
     }
 
-    private ChunkedTextEmbeddingResults translateChunkedResults(List<InferenceResults> inferenceResults) {
-        if (inferenceResults.size() != 1) {
-            throw new ElasticsearchStatusException("Expected exactly one chunked text embedding result", RestStatus.INTERNAL_SERVER_ERROR);
-        }
+    private List<ChunkedInferenceServiceResults> translateChunkedResults(List<InferenceResults> inferenceResults) {
+        var translated = new ArrayList<ChunkedInferenceServiceResults>();
 
-        if (inferenceResults.get(
-            0
-        ) instanceof org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults mlChunkedResult) {
-            return ChunkedTextEmbeddingResults.ofMlResult(mlChunkedResult);
-        } else {
-            throw new ElasticsearchStatusException(
-                "Expected a chunked inference [{}] received [{}]",
-                RestStatus.INTERNAL_SERVER_ERROR,
-                ChunkedTextExpansionResults.NAME,
-                inferenceResults.get(0).getWriteableName()
-            );
+        for (var inferenceResult : inferenceResults) {
+            if (inferenceResult instanceof org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults mlChunkedResult) {
+                translated.add(ChunkedTextEmbeddingResults.ofMlResult(mlChunkedResult));
+            } else {
+                throw new ElasticsearchStatusException(
+                    "Expected a chunked inference [{}] received [{}]",
+                    RestStatus.INTERNAL_SERVER_ERROR,
+                    ChunkedTextEmbeddingResults.NAME,
+                    inferenceResult.getWriteableName()
+                );
+            }
         }
+
+        return translated;
     }
 
     @Override

+ 1 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java

@@ -121,7 +121,7 @@ public class SenderServiceTests extends ESTestCase {
             Map<String, Object> taskSettings,
             InputType inputType,
             ChunkingOptions chunkingOptions,
-            ActionListener<ChunkedInferenceServiceResults> listener
+            ActionListener<List<ChunkedInferenceServiceResults>> listener
         ) {
 
         }

+ 81 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java

@@ -11,19 +11,37 @@ package org.elasticsearch.xpack.inference.services.elser;
 
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceExtension;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResultsTests;
 
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
 
 import static org.hamcrest.Matchers.containsString;
+import static org.hamcrest.Matchers.hasSize;
+import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class ElserInternalServiceTests extends ESTestCase {
 
@@ -307,6 +325,69 @@ public class ElserInternalServiceTests extends ESTestCase {
         }
     }
 
+    @SuppressWarnings("unchecked")
+    public void testChunkInfer() {
+        var mlTrainedModelResults = new ArrayList<InferenceResults>();
+        mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults());
+        mlTrainedModelResults.add(ChunkedTextExpansionResultsTests.createRandomResults());
+        var response = new InferTrainedModelDeploymentAction.Response(mlTrainedModelResults);
+
+        ThreadPool threadpool = new TestThreadPool("test");
+        Client client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadpool);
+        doAnswer(invocationOnMock -> {
+            var listener = (ActionListener<InferTrainedModelDeploymentAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(response);
+            return null;
+        }).when(client)
+            .execute(
+                same(InferTrainedModelDeploymentAction.INSTANCE),
+                any(InferTrainedModelDeploymentAction.Request.class),
+                any(ActionListener.class)
+            );
+
+        var model = new ElserInternalModel(
+            "foo",
+            TaskType.SPARSE_EMBEDDING,
+            "elser",
+            new ElserInternalServiceSettings(1, 1, "elser"),
+            new ElserMlNodeTaskSettings()
+        );
+        var service = createService(client);
+
+        var gotResults = new AtomicBoolean();
+        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+            assertThat(chunkedResponse, hasSize(2));
+            assertThat(chunkedResponse.get(0), instanceOf(ChunkedSparseEmbeddingResults.class));
+            var result1 = (ChunkedSparseEmbeddingResults) chunkedResponse.get(0);
+            assertEquals(
+                ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults) mlTrainedModelResults.get(0)).getChunks(),
+                result1.getChunkedResults()
+            );
+            assertThat(chunkedResponse.get(1), instanceOf(ChunkedSparseEmbeddingResults.class));
+            var result2 = (ChunkedSparseEmbeddingResults) chunkedResponse.get(1);
+            assertEquals(
+                ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextExpansionResults) mlTrainedModelResults.get(1)).getChunks(),
+                result2.getChunkedResults()
+            );
+            gotResults.set(true);
+        }, ESTestCase::fail);
+
+        service.chunkedInfer(
+            model,
+            List.of("foo", "bar"),
+            Map.of(),
+            InputType.SEARCH,
+            new ChunkingOptions(null, null),
+            ActionListener.runAfter(resultsListener, () -> terminate(threadpool))
+        );
+
+        if (gotResults.get() == false) {
+            terminate(threadpool);
+        }
+        assertTrue("Listener not called", gotResults.get());
+    }
+
     private ElserInternalService createService(Client client) {
         var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
         return new ElserInternalService(context);

+ 78 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/textembedding/TextEmbeddingInternalServiceTests.java

@@ -12,22 +12,38 @@ package org.elasticsearch.xpack.inference.services.textembedding;
 import org.elasticsearch.ElasticsearchStatusException;
 import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.client.internal.Client;
+import org.elasticsearch.inference.ChunkedInferenceServiceResults;
+import org.elasticsearch.inference.ChunkingOptions;
+import org.elasticsearch.inference.InferenceResults;
 import org.elasticsearch.inference.InferenceServiceExtension;
+import org.elasticsearch.inference.InputType;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.test.ESTestCase;
+import org.elasticsearch.threadpool.TestThreadPool;
+import org.elasticsearch.threadpool.ThreadPool;
+import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
+import org.elasticsearch.xpack.core.ml.action.InferTrainedModelDeploymentAction;
+import org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResultsTests;
 import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Random;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicBoolean;
 
+import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.instanceOf;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.same;
+import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 public class TextEmbeddingInternalServiceTests extends ESTestCase {
 
@@ -325,6 +341,68 @@ public class TextEmbeddingInternalServiceTests extends ESTestCase {
         }
     }
 
+    @SuppressWarnings("unchecked")
+    public void testChunkInfer() {
+        var mlTrainedModelResults = new ArrayList<InferenceResults>();
+        mlTrainedModelResults.add(ChunkedTextEmbeddingResultsTests.createRandomResults());
+        mlTrainedModelResults.add(ChunkedTextEmbeddingResultsTests.createRandomResults());
+        var response = new InferTrainedModelDeploymentAction.Response(mlTrainedModelResults);
+
+        ThreadPool threadpool = new TestThreadPool("test");
+        Client client = mock(Client.class);
+        when(client.threadPool()).thenReturn(threadpool);
+        doAnswer(invocationOnMock -> {
+            var listener = (ActionListener<InferTrainedModelDeploymentAction.Response>) invocationOnMock.getArguments()[2];
+            listener.onResponse(response);
+            return null;
+        }).when(client)
+            .execute(
+                same(InferTrainedModelDeploymentAction.INSTANCE),
+                any(InferTrainedModelDeploymentAction.Request.class),
+                any(ActionListener.class)
+            );
+
+        var model = new MultilingualE5SmallModel(
+            "foo",
+            TaskType.TEXT_EMBEDDING,
+            "e5",
+            new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform")
+        );
+        var service = createService(client);
+
+        var gotResults = new AtomicBoolean();
+        var resultsListener = ActionListener.<List<ChunkedInferenceServiceResults>>wrap(chunkedResponse -> {
+            assertThat(chunkedResponse, hasSize(2));
+            assertThat(chunkedResponse.get(0), instanceOf(ChunkedTextEmbeddingResults.class));
+            var result1 = (ChunkedTextEmbeddingResults) chunkedResponse.get(0);
+            assertEquals(
+                ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults) mlTrainedModelResults.get(0)).getChunks(),
+                result1.getChunks()
+            );
+            assertThat(chunkedResponse.get(1), instanceOf(ChunkedTextEmbeddingResults.class));
+            var result2 = (ChunkedTextEmbeddingResults) chunkedResponse.get(1);
+            assertEquals(
+                ((org.elasticsearch.xpack.core.ml.inference.results.ChunkedTextEmbeddingResults) mlTrainedModelResults.get(1)).getChunks(),
+                result2.getChunks()
+            );
+            gotResults.set(true);
+        }, ESTestCase::fail);
+
+        service.chunkedInfer(
+            model,
+            List.of("foo", "bar"),
+            Map.of(),
+            InputType.SEARCH,
+            new ChunkingOptions(null, null),
+            ActionListener.runAfter(resultsListener, () -> terminate(threadpool))
+        );
+
+        if (gotResults.get() == false) {
+            terminate(threadpool);
+        }
+        assertTrue("Listener not called", gotResults.get());
+    }
+
     private TextEmbeddingInternalService createService(Client client) {
         var context = new InferenceServiceExtension.InferenceServiceFactoryContext(client);
         return new TextEmbeddingInternalService(context);