瀏覽代碼

[ML] Add stream flag to inference providers (#113424) (#113628)

Pass the stream flag from the REST request through to the inference
providers via the InferenceInputs.

Co-authored-by: Elastic Machine <elasticmachine@users.noreply.github.com>
Pat Whelan 1 年之前
父節點
當前提交
886280c7cb
共有 29 個文件被更改,包括 91 次插入24 次删除
  1. 2 0
      server/src/main/java/org/elasticsearch/inference/InferenceService.java
  2. 1 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java
  3. 1 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java
  4. 1 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java
  5. 1 0
      x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java
  6. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java
  7. 1 1
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpUtils.java
  8. 12 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java
  9. 15 6
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java
  10. 3 2
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
  11. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java
  12. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java
  13. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java
  14. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java
  15. 1 0
      x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java
  16. 9 12
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java
  17. 4 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
  18. 2 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
  19. 3 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
  20. 3 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
  21. 6 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
  22. 2 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java
  23. 4 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
  24. 1 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java
  25. 2 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java
  26. 3 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java
  27. 2 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java
  28. 3 0
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
  29. 4 1
      x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java

+ 2 - 0
server/src/main/java/org/elasticsearch/inference/InferenceService.java

@@ -85,6 +85,7 @@ public interface InferenceService extends Closeable {
      * @param model        The model
      * @param query        Inference query, mainly for re-ranking
      * @param input        Inference input
+     * @param stream       Stream inference results
      * @param taskSettings Settings in the request to override the model's defaults
      * @param inputType    For search, ingest etc
      * @param timeout      The timeout for the request
@@ -94,6 +95,7 @@ public interface InferenceService extends Closeable {
         Model model,
         @Nullable String query,
         List<String> input,
+        boolean stream,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,

+ 1 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java

@@ -94,6 +94,7 @@ public class TestDenseInferenceServiceExtension implements InferenceServiceExten
             Model model,
             @Nullable String query,
             List<String> input,
+            boolean stream,
             Map<String, Object> taskSettings,
             InputType inputType,
             TimeValue timeout,

+ 1 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java

@@ -85,6 +85,7 @@ public class TestRerankingServiceExtension implements InferenceServiceExtension
             Model model,
             @Nullable String query,
             List<String> input,
+            boolean stream,
             Map<String, Object> taskSettings,
             InputType inputType,
             TimeValue timeout,

+ 1 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java

@@ -88,6 +88,7 @@ public class TestSparseInferenceServiceExtension implements InferenceServiceExte
             Model model,
             @Nullable String query,
             List<String> input,
+            boolean stream,
             Map<String, Object> taskSettings,
             InputType inputType,
             TimeValue timeout,

+ 1 - 0
x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java

@@ -85,6 +85,7 @@ public class TestStreamingCompletionServiceExtension implements InferenceService
             Model model,
             String query,
             List<String> input,
+            boolean stream,
             Map<String, Object> taskSettings,
             InputType inputType,
             TimeValue timeout,

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java

@@ -114,6 +114,7 @@ public class TransportInferenceAction extends HandledTransportAction<InferenceAc
                 model,
                 request.getQuery(),
                 request.getInput(),
+                request.isStreaming(),
                 request.getTaskSettings(),
                 request.getInputType(),
                 request.getInferenceTimeout(),

+ 1 - 1
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpUtils.java

@@ -46,7 +46,7 @@ public class HttpUtils {
     }
 
     public static void checkForEmptyBody(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) {
-        if (result.isBodyEmpty()) {
+        if (result.isBodyEmpty() && (request.isStreaming() == false)) {
             String message = format("Response body was empty for request from inference entity id [%s]", request.getInferenceEntityId());
             throttlerManager.warn(logger, message);
             throw new IllegalStateException(message);

+ 12 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java

@@ -21,13 +21,23 @@ public class DocumentsOnlyInput extends InferenceInputs {
     }
 
     private final List<String> input;
+    private final boolean stream;
 
-    public DocumentsOnlyInput(List<String> chunks) {
+    public DocumentsOnlyInput(List<String> input) {
+        this(input, false);
+    }
+
+    public DocumentsOnlyInput(List<String> input, boolean stream) {
         super();
-        this.input = Objects.requireNonNull(chunks);
+        this.input = Objects.requireNonNull(input);
+        this.stream = stream;
     }
 
     public List<String> getInputs() {
         return this.input;
     }
+
+    public boolean stream() {
+        return stream;
+    }
 }

+ 15 - 6
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java

@@ -21,6 +21,19 @@ public class QueryAndDocsInputs extends InferenceInputs {
     }
 
     private final String query;
+    private final List<String> chunks;
+    private final boolean stream;
+
+    public QueryAndDocsInputs(String query, List<String> chunks) {
+        this(query, chunks, false);
+    }
+
+    public QueryAndDocsInputs(String query, List<String> chunks, boolean stream) {
+        super();
+        this.query = Objects.requireNonNull(query);
+        this.chunks = Objects.requireNonNull(chunks);
+        this.stream = stream;
+    }
 
     public String getQuery() {
         return query;
@@ -30,12 +43,8 @@ public class QueryAndDocsInputs extends InferenceInputs {
         return chunks;
     }
 
-    List<String> chunks;
-
-    public QueryAndDocsInputs(String query, List<String> chunks) {
-        super();
-        this.query = Objects.requireNonNull(query);
-        this.chunks = Objects.requireNonNull(chunks);
+    public boolean stream() {
+        return stream;
     }
 
 }

+ 3 - 2
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java

@@ -51,6 +51,7 @@ public abstract class SenderService implements InferenceService {
         Model model,
         @Nullable String query,
         List<String> input,
+        boolean stream,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,
@@ -58,9 +59,9 @@ public abstract class SenderService implements InferenceService {
     ) {
         init();
         if (query != null) {
-            doInfer(model, new QueryAndDocsInputs(query, input), taskSettings, inputType, timeout, listener);
+            doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener);
         } else {
-            doInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, timeout, listener);
+            doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener);
         }
     }
 

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java

@@ -659,6 +659,7 @@ public final class ServiceUtils {
             model,
             null,
             List.of(TEST_EMBEDDING_INPUT),
+            false,
             Map.of(),
             InputType.INGEST,
             InferenceAction.Request.DEFAULT_TIMEOUT,

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java

@@ -309,6 +309,7 @@ public class AlibabaCloudSearchService extends SenderService {
             model,
             query,
             List.of(input),
+            false,
             Map.of(),
             InputType.INGEST,
             DEFAULT_TIMEOUT,

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java

@@ -323,6 +323,7 @@ public class ElasticsearchInternalService extends BaseElasticsearchInternalServi
         Model model,
         @Nullable String query,
         List<String> input,
+        boolean stream,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java

@@ -149,6 +149,7 @@ public class ElserInternalService extends BaseElasticsearchInternalService {
         Model model,
         @Nullable String query,
         List<String> inputs,
+        boolean stream,
         Map<String, Object> taskSettings,
         InputType inputType,
         TimeValue timeout,

+ 1 - 0
x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java

@@ -30,6 +30,7 @@ public class SimpleServiceIntegrationValidator implements ServiceIntegrationVali
             model,
             model.getTaskType().equals(TaskType.RERANK) ? QUERY : null,
             TEST_INPUT,
+            false,
             Map.of(),
             InputType.INGEST,
             InferenceAction.Request.DEFAULT_TIMEOUT,

+ 9 - 12
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ServiceUtilsTests.java

@@ -49,6 +49,7 @@ import static org.hamcrest.Matchers.hasSize;
 import static org.hamcrest.Matchers.is;
 import static org.hamcrest.Matchers.nullValue;
 import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.anyBoolean;
 import static org.mockito.Mockito.doAnswer;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -855,12 +856,11 @@ public class ServiceUtilsTests extends ESTestCase {
         when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
 
         doAnswer(invocation -> {
-            @SuppressWarnings("unchecked")
-            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
             listener.onResponse(new InferenceTextEmbeddingFloatResults(List.of()));
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);
@@ -878,12 +878,11 @@ public class ServiceUtilsTests extends ESTestCase {
         when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING);
 
         doAnswer(invocation -> {
-            @SuppressWarnings("unchecked")
-            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
             listener.onResponse(new InferenceTextEmbeddingByteResults(List.of()));
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);
@@ -903,12 +902,11 @@ public class ServiceUtilsTests extends ESTestCase {
         var textEmbedding = TextEmbeddingResultsTests.createRandomResults();
 
         doAnswer(invocation -> {
-            @SuppressWarnings("unchecked")
-            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
             listener.onResponse(textEmbedding);
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);
@@ -927,12 +925,11 @@ public class ServiceUtilsTests extends ESTestCase {
         var textEmbedding = InferenceTextEmbeddingByteResultsTests.createRandomResults();
 
         doAnswer(invocation -> {
-            @SuppressWarnings("unchecked")
-            ActionListener<InferenceServiceResults> listener = (ActionListener<InferenceServiceResults>) invocation.getArguments()[6];
+            ActionListener<InferenceServiceResults> listener = invocation.getArgument(7);
             listener.onResponse(textEmbedding);
 
             return Void.TYPE;
-        }).when(service).infer(any(), any(), any(), any(), any(), any(), any());
+        }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any());
 
         PlainActionFuture<Integer> listener = new PlainActionFuture<>();
         getEmbeddingSize(model, service, listener);

+ 4 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java

@@ -671,6 +671,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -721,6 +722,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                     model,
                     null,
                     List.of("abc"),
+                    false,
                     new HashMap<>(),
                     InputType.INGEST,
                     InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -762,6 +764,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                     model,
                     null,
                     List.of("abc"),
+                    false,
                     new HashMap<>(),
                     InputType.INGEST,
                     InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1025,6 +1028,7 @@ public class AmazonBedrockServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 2 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java

@@ -452,6 +452,7 @@ public class AnthropicServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -506,6 +507,7 @@ public class AnthropicServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("input"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 3 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java

@@ -825,6 +825,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -953,6 +954,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1003,6 +1005,7 @@ public class AzureAiStudioServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 3 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java

@@ -601,6 +601,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -656,6 +657,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1051,6 +1053,7 @@ public class AzureOpenAiServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 6 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java

@@ -622,6 +622,7 @@ public class CohereServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -689,6 +690,7 @@ public class CohereServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -932,6 +934,7 @@ public class CohereServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -991,6 +994,7 @@ public class CohereServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1064,6 +1068,7 @@ public class CohereServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1135,6 +1140,7 @@ public class CohereServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.UNSPECIFIED,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 2 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java

@@ -346,6 +346,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -397,6 +398,7 @@ public class ElasticInferenceServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("input text"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 4 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java

@@ -503,6 +503,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -578,6 +579,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("input"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -634,6 +636,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of(input),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -774,6 +777,7 @@ public class GoogleAiStudioServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 1 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java

@@ -69,6 +69,7 @@ public class HuggingFaceBaseServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 2 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java

@@ -438,6 +438,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -481,6 +482,7 @@ public class HuggingFaceServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 3 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java

@@ -409,6 +409,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -465,6 +466,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of(input),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -588,6 +590,7 @@ public class IbmWatsonxServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 2 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java

@@ -402,6 +402,7 @@ public class MistralServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -526,6 +527,7 @@ public class MistralServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 3 - 0
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java

@@ -936,6 +936,7 @@ public class OpenAiServiceTests extends ESTestCase {
                 mockModel,
                 null,
                 List.of(""),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -990,6 +991,7 @@ public class OpenAiServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,
@@ -1469,6 +1471,7 @@ public class OpenAiServiceTests extends ESTestCase {
                 model,
                 null,
                 List.of("abc"),
+                false,
                 new HashMap<>(),
                 InputType.INGEST,
                 InferenceAction.Request.DEFAULT_TIMEOUT,

+ 4 - 1
x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java

@@ -64,6 +64,7 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
                 eq(mockModel),
                 eq(null),
                 eq(TEST_INPUT),
+                eq(false),
                 eq(Map.of()),
                 eq(InputType.INGEST),
                 eq(InferenceAction.Request.DEFAULT_TIMEOUT),
@@ -94,7 +95,7 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
 
     private void mockSuccessfulCallToService(String query, InferenceServiceResults result) {
         doAnswer(ans -> {
-            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(6);
+            ActionListener<InferenceServiceResults> responseListener = ans.getArgument(7);
             responseListener.onResponse(result);
             return null;
         }).when(mockInferenceService)
@@ -102,6 +103,7 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
                 eq(mockModel),
                 eq(query),
                 eq(TEST_INPUT),
+                eq(false),
                 eq(Map.of()),
                 eq(InputType.INGEST),
                 eq(InferenceAction.Request.DEFAULT_TIMEOUT),
@@ -117,6 +119,7 @@ public class SimpleServiceIntegrationValidatorTests extends ESTestCase {
             eq(mockModel),
             eq(withQuery ? TEST_QUERY : null),
             eq(TEST_INPUT),
+            eq(false),
             eq(Map.of()),
             eq(InputType.INGEST),
             eq(InferenceAction.Request.DEFAULT_TIMEOUT),